🚌 Welcome Tour
Welcome to MosaicML Composer! This guide walks you through the basics of how the Composer trainer works and how it interacts with our methods library. It assumes that you have already gone through the installation instructions.
Our First Method!
We’re going to explore MixUp (Zhang et al., 2017), a simple regularization technique that tends to improve the accuracy of image classification models.
MixUp operates by modifying the batches of data used to train the model; instead of training on individual samples, we train on convex combinations of samples. Thus, our implementation of the MixUp algorithm needs to be able to modify batches of training data after they are loaded from the dataloader and before they are passed into the forward pass of the model.
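To make "convex combinations of samples" concrete, here is a minimal sketch in plain Python. It is not Composer's implementation: the function name `mixup_pairs` and the list-based batch are assumptions made for illustration, and the mixing coefficient is drawn from Beta(alpha, alpha) as in the original paper.

```python
import random


def mixup_pairs(inputs, targets, alpha=0.2, rng=None):
    """Toy MixUp on a batch of flat feature lists (illustrative only).

    Returns the mixed inputs, the permuted targets, and the mixing coefficient.
    """
    rng = rng or random.Random()
    # Sample the mixing coefficient from Beta(alpha, alpha).
    mixing = rng.betavariate(alpha, alpha)
    # Pair every sample with another sample from the same batch.
    perm = list(range(len(inputs)))
    rng.shuffle(perm)
    # Each mixed sample is a convex combination of a sample and its partner.
    mixed = [
        [(1 - mixing) * a + mixing * b for a, b in zip(inputs[i], inputs[j])]
        for i, j in enumerate(perm)
    ]
    targets_perm = [targets[j] for j in perm]
    return mixed, targets_perm, mixing


batch_x = [[0.0, 0.0], [1.0, 1.0], [2.0, 4.0]]
batch_y = [0, 1, 2]
mixed_x, y_perm, lam = mixup_pairs(batch_x, batch_y, alpha=0.2, rng=random.Random(0))
```

With a small `alpha`, most draws of the mixing coefficient fall close to 0 or 1, so the typical mixed sample stays close to one of its two source samples.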
For more information on MixUp, see 🥣 MixUp in our methods library.
So how can we use MixUp within a trainer?
A Simple Training Loop
A very simple PyTorch training loop looks something like the following:
    for epoch in range(NUM_EPOCHS):
        for inputs, targets in dataloader:
            outputs = model.forward(inputs)
            loss = model.loss(outputs, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
MixUp needs to modify inputs and targets after they are loaded from the dataloader but before the inputs are passed to the forward pass of the model. One possibility is to use our functional API to modify the training loop:
    from composer import functional as cf

    for epoch in range(NUM_EPOCHS):
        for inputs, targets in dataloader:
            # Mix each sample with a randomly paired sample from the same batch
            inputs, targets_perm, mixing = cf.mixup_batch(inputs, targets, alpha=0.2)
            outputs = model.forward(inputs)
            # Interpolate the loss between the original and the permuted targets
            loss = (1 - mixing) * model.loss(outputs, targets) + mixing * model.loss(outputs, targets_perm)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
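The interpolated loss in this loop has a simple interpretation when the loss is cross-entropy, which is an assumption here since the guide does not say what `model.loss` computes. Cross-entropy is linear in the target distribution, so weighting the two hard-label losses gives the same value as a single loss against the interpolated label. The helper names below are hypothetical and the sketch uses only the standard library:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cross_entropy(logits, target_probs):
    """Cross-entropy between predicted logits and a target distribution."""
    return -sum(t * lp for t, lp in zip(target_probs, log_softmax(logits)))


def one_hot(index, num_classes):
    return [1.0 if k == index else 0.0 for k in range(num_classes)]


logits = [2.0, -1.0, 0.5]   # model output for one mixed sample
target, target_perm = 0, 2  # original label and label of the paired sample
mixing = 0.3

# Form used in the training loop: interpolate two hard-label losses.
loss_two_terms = (1 - mixing) * cross_entropy(logits, one_hot(target, 3)) \
    + mixing * cross_entropy(logits, one_hot(target_perm, 3))

# Equivalent form: a single loss against the interpolated (soft) label.
soft_label = [(1 - mixing) * a + mixing * b
              for a, b in zip(one_hot(target, 3), one_hot(target_perm, 3))]
loss_soft_label = cross_entropy(logits, soft_label)
```

The two values agree up to floating-point error. For losses that are not linear in the target, the two forms can differ, and the loop above uses the two-term form.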
This works, and we recommend it if you want to quickly modify an existing training loop to use our implementation of MixUp. However, the goal of Composer is to enable rapid experimentation with different combinations of algorithms. Our methods library contains over 20 different methods to experiment with, and it would be unwieldy to add conditional logic to the trainer for enabling or disabling each one. This is where the Composer trainer comes in.
Events, Engines, and State
The core principle of the Composer trainer is to avoid the need to introduce algorithm-specific logic to the trainer
by instead relying on callbacks tied to events. Events describe specific stages of the training lifecycle, such as
BEFORE_FORWARD. This design is based on the two-way callback system of Howard et al. (2020).
We could add events to our training loop as follows:
    # <INIT>
    # <FIT_START>
    for epoch in range(NUM_EPOCHS):
        # <EPOCH_START>
        for inputs, targets in dataloader:
            # <AFTER_DATALOADER>

            # <BATCH_START>

            # <BEFORE_FORWARD>
            outputs = model.forward(inputs)
            # <AFTER_FORWARD>

            # <BEFORE_LOSS>
            loss = model.loss(outputs, targets)
            # <AFTER_LOSS>

            # <BEFORE_BACKWARD>
            loss.backward()
            # <AFTER_BACKWARD>

            optimizer.step()
            optimizer.zero_grad()

            # <BATCH_END>
        # <EPOCH_END>
Now we need a way to tie events to algorithms, so that we know which algorithms to run and when to run them.
This is the purpose of the Engine. The Engine is initialized with a list of algorithms to run and provides a composer.core.Engine.run_event() method that the trainer can call to execute algorithms for the given event. The Engine is also responsible for handling potential conflicts between multiple algorithms.
One piece is missing. Algorithms no longer run from within the body of the training loop, but they still need to be able to modify the training loop’s state. For this, we introduce State, which stores all objects relevant to training that algorithms need access to. The Engine is initialized with a reference to the State and passes it to algorithms when it invokes them.
Finally, to be compatible with the Engine, algorithms need to implement two methods: match() and apply(). For MixUp, these methods can be very simple:
    # Import paths as in recent Composer releases; adjust for your version.
    from composer.core import Algorithm, Event, State
    from composer.functional import mixup_batch
    from composer.loggers import Logger


    class MixUp(Algorithm):

        def match(self, event: Event, state: State) -> bool:
            """Determine whether the algorithm should run on a given event."""
            return event in [Event.AFTER_DATALOADER, Event.AFTER_LOSS]

        def apply(self, event: Event, state: State, logger: Logger) -> None:
            """Run the algorithm by modifying the State."""
            input, target = state.batch

            if event == Event.AFTER_DATALOADER:
                # Mix the batch and remember the pairing for the loss step
                new_input, self.permuted_target, self.mixing = mixup_batch(input, target, alpha=0.2)
                state.batch = (new_input, target)

            if event == Event.AFTER_LOSS:
                # Interpolate the loss between the original and the permuted targets
                modified_batch = (input, self.permuted_target)
                new_loss = state.model.loss(state.outputs, modified_batch)
                state.loss *= (1 - self.mixing)
                state.loss += self.mixing * new_loss
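To see how match() and apply() interact with an engine and a shared state, here is a toy version in plain Python. Everything in it is hypothetical (`ToyState`, `ToyEngine`, and the two example algorithms are not part of Composer), and it leaves out what the real Engine adds, such as conflict handling and the logger argument.

```python
class ToyState:
    """Holds the mutable training objects that algorithms may read or modify."""

    def __init__(self):
        self.batch = None
        self.loss = None


class ScaleInputs:
    """Hypothetical algorithm: doubles every input right after it is loaded."""

    def match(self, event, state):
        return event == "after_dataloader"

    def apply(self, event, state):
        inputs, targets = state.batch
        state.batch = ([2 * x for x in inputs], targets)


class LossOffset:
    """Hypothetical algorithm: adds a constant penalty once the loss is computed."""

    def match(self, event, state):
        return event == "after_loss"

    def apply(self, event, state):
        state.loss += 1.0


class ToyEngine:
    """Runs every registered algorithm whose match() accepts the current event."""

    def __init__(self, state, algorithms):
        self.state = state
        self.algorithms = algorithms
        self.trace = []  # (event, algorithm name) pairs, kept for inspection

    def run_event(self, event):
        for algorithm in self.algorithms:
            if algorithm.match(event, self.state):
                algorithm.apply(event, self.state)
                self.trace.append((event, type(algorithm).__name__))


state = ToyState()
engine = ToyEngine(state, algorithms=[ScaleInputs(), LossOffset()])

# The loop body never mentions a specific algorithm.
for batch in [([1.0, 2.0], [3.0, 5.0])]:
    state.batch = batch
    engine.run_event("after_dataloader")
    inputs, targets = state.batch
    engine.run_event("before_loss")
    state.loss = sum((x - y) ** 2 for x, y in zip(inputs, targets))
    engine.run_event("after_loss")
```

Adding or removing an algorithm only changes the list passed to the engine. The loop itself stays the same, which is the property the event system is designed to provide.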
Putting all the pieces together, our trainer looks something like this:
    from composer import Time
    from composer.core import Engine, State

    state = State(
        model=your_model,  # ComposerModel
        max_duration=Time(10, 'epoch'),
        rank_zero_seed=0,
        dataloader=your_dataloader,  # torch.utils.data.DataLoader
    )

    engine = Engine(state=state, algorithms=[MixUp()])

    engine.run_event("init")
    engine.run_event("fit_start")
    for epoch in range(NUM_EPOCHS):
        engine.run_event("epoch_start")
        for state.inputs, state.targets in dataloader:
            engine.run_event("after_dataloader")

            engine.run_event("batch_start")

            engine.run_event("before_forward")
            state.outputs = state.model.forward(state.inputs)
            engine.run_event("after_forward")

            engine.run_event("before_loss")
            state.loss = state.model.loss(state.outputs, state.targets)
            engine.run_event("after_loss")

            engine.run_event("before_backward")
            state.loss.backward()
            engine.run_event("after_backward")

            state.optimizers.step()
            state.schedulers.step()

            engine.run_event("batch_end")
        engine.run_event("epoch_end")
That’s it! MixUp will automatically run on "after_dataloader" and "after_loss". And because all of the events are already present in the training loop, we can easily start using new algorithms as well!
For more information on events, state, and engines, check out
For advanced experimentation, we recommend using the Composer Trainer. The trainer not only takes care of all the state management and event callbacks from above, but also adds advanced features like hyperparameter management, gradient accumulation, and closure support.