Algorithm

class composer.Algorithm(*args, **kwargs)

Base class for algorithms.

Algorithms are pieces of code which run at specific events (see Event) in the training loop. Algorithms modify the trainer's State, generally with the effect of improving the model's quality or increasing the efficiency and throughput of the training loop.

Algorithms must implement the following two methods:

Method    Description
match()   Returns whether the algorithm should be run given the current Event and State.
apply()   Executes the algorithm's code and makes an in-place change to the State.
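The two-method contract can be illustrated with a small self-contained sketch. It does not import Composer: `Event`, `State`, `Algorithm`, `HalveLossScale`, and `run_event` below are hypothetical stand-ins for Composer's own types and engine, kept just detailed enough to show that `apply()` runs only when `match()` returns True and that it changes the state in place.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Event(Enum):
    # Hypothetical subset of training-loop events (stand-in for composer.Event).
    BATCH_START = "batch_start"
    BEFORE_LOSS = "before_loss"
    BATCH_END = "batch_end"


@dataclass
class State:
    # Hypothetical stand-in for composer's State; only one field is modeled.
    loss_scale: float = 1.0


class Algorithm(ABC):
    # Stand-in base class mirroring the two abstract methods described above.
    @abstractmethod
    def match(self, event: Event, state: State) -> bool: ...

    @abstractmethod
    def apply(self, event: Event, state: State, logger: object) -> Optional[int]: ...


class HalveLossScale(Algorithm):
    # Toy algorithm: halves a state attribute each time BEFORE_LOSS fires.
    def match(self, event, state):
        return event == Event.BEFORE_LOSS

    def apply(self, event, state, logger):
        state.loss_scale *= 0.5  # in-place change to the state
        return None


def run_event(event, state, algorithms, logger=None):
    # Hypothetical engine step: call apply() only for algorithms whose match() is True.
    for algorithm in algorithms:
        if algorithm.match(event, state):
            algorithm.apply(event, state, logger)


state = State()
for event in Event:  # one pass through the toy "batch", in definition order
    run_event(event, state, [HalveLossScale()])
print(state.loss_scale)  # 0.5
```

Only BEFORE_LOSS matches, so the attribute is halved exactly once per pass.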

abstract apply(event, state, logger)

Applies the algorithm to make an in-place change to the State.

Can optionally return an exit code to be stored in a Trace. This exit code is made accessible for debugging.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

  • logger (Logger) – A logger to use for logging algorithm-specific metrics.

Returns

int or None – Exit code that will be stored in Trace and made accessible for debugging.
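To make the exit-code contract concrete, here is a hypothetical engine step that records each algorithm's return value. The `Trace` namedtuple and its fields, the `SkipIfEmpty` algorithm, the string events, the dict-based state, and `run_and_trace` are all assumptions for illustration; Composer's actual Trace type may carry different fields.

```python
from typing import NamedTuple, Optional


class Trace(NamedTuple):
    # Hypothetical trace record; fields chosen for illustration only.
    name: str
    event: str
    run: bool
    exit_code: Optional[int]


class SkipIfEmpty:
    # Toy algorithm that signals, via its exit code, whether it did anything.
    def match(self, event, state):
        return event == "before_loss"

    def apply(self, event, state, logger):
        if not state["batch"]:
            return 1  # nothing to do; report a nonzero code for debugging
        state["batch"] = [x * 2 for x in state["batch"]]
        return 0


def run_and_trace(event, state, algorithms, logger=None):
    # Run matching algorithms and keep every exit code in a per-event dict.
    traces = {}
    for algorithm in algorithms:
        name = type(algorithm).__name__
        if algorithm.match(event, state):
            code = algorithm.apply(event, state, logger)
            traces[f"{event}/{name}"] = Trace(name, event, True, code)
        else:
            traces[f"{event}/{name}"] = Trace(name, event, False, None)
    return traces


state = {"batch": [1, 2, 3]}
traces = run_and_trace("before_loss", state, [SkipIfEmpty()])
print(traces["before_loss/SkipIfEmpty"].exit_code)  # 0
print(state["batch"])  # [2, 4, 6]
```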

property backwards_create_graph

Whether this algorithm requires the backwards pass to be differentiable. Defaults to False.

If it returns True, create_graph=True will be passed to torch.Tensor.backward(), so the graph of the gradient is also constructed. This allows second-order (and higher-order) derivatives to be computed.
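The effect of create_graph=True can be seen with plain PyTorch, independent of Composer. In this sketch the call to `backward()` builds a graph for the gradient itself, so the gradient of y = x³ can be differentiated a second time.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# create_graph=True makes x.grad (= 3x^2) part of the autograd graph.
y.backward(create_graph=True)
first = x.grad
print(first.item())  # 27.0

# Because the gradient has its own graph, it can be differentiated again.
(second,) = torch.autograd.grad(first, x)
print(second.item())  # 18.0 (= 6x)
```

Without create_graph=True, x.grad would not require gradients, and the second `torch.autograd.grad` call would fail.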

property find_unused_parameters

Indicates whether this algorithm may cause some model parameters to be unused. Defaults to False.

For example, it is used to tell torch.nn.parallel.DistributedDataParallel (DDP) that some parameters will be frozen during training and hence that it should not expect gradients from them. All algorithms that do any kind of parameter freezing should override this property to return True.
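One plausible way a trainer could consume this flag is to OR it across all active algorithms and forward the result to DDP's `find_unused_parameters` constructor argument. The classes `BaseAlgorithm`, `FreezeFirstLayer`, and `NoFreezing` and the `ddp_kwargs` helper are hypothetical; only the `find_unused_parameters` keyword of torch.nn.parallel.DistributedDataParallel is real. DDP itself is not constructed here because that requires an initialized process group.

```python
class BaseAlgorithm:
    # Hypothetical base mirroring the documented default of False.
    @property
    def find_unused_parameters(self) -> bool:
        return False


class FreezeFirstLayer(BaseAlgorithm):
    # Toy parameter-freezing algorithm: frozen weights receive no gradients,
    # so it overrides the property to return True.
    @property
    def find_unused_parameters(self) -> bool:
        return True


class NoFreezing(BaseAlgorithm):
    # Toy algorithm that leaves every parameter in use; keeps the default.
    pass


def ddp_kwargs(algorithms):
    # DDP must tolerate missing gradients if any algorithm may leave parameters unused.
    return {"find_unused_parameters": any(a.find_unused_parameters for a in algorithms)}


print(ddp_kwargs([NoFreezing()]))                      # {'find_unused_parameters': False}
print(ddp_kwargs([NoFreezing(), FreezeFirstLayer()]))  # {'find_unused_parameters': True}
```

In a real run, the dict would be unpacked into the wrapper, e.g. `DistributedDataParallel(model, **ddp_kwargs(algorithms))`.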

abstract match(event, state)

Determines whether this algorithm should run given the current Event and State.

Examples: To only run on a specific event (e.g., on Event.BEFORE_LOSS), override match as shown below:

>>> from composer import Event
>>> # Assumes state is an existing State object.
>>> class MyAlgorithm:
...     def match(self, event, state):
...         return event == Event.BEFORE_LOSS
>>> MyAlgorithm().match(Event.BEFORE_LOSS, state)
True

To run based on some value of a State attribute, override match as shown below:

>>> class MyAlgorithm:
...     def match(self, event, state):
...         return state.timestamp.epoch > 30
>>> MyAlgorithm().match(Event.BEFORE_LOSS, state)
False

See State for accessible attributes.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.

static required_on_load()

Return True to indicate this algorithm is required when loading from a checkpoint which used it.
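Because required_on_load() is a static method, a loader can query it on the class without instantiating the algorithm. The sketch below is a hypothetical checkpoint check: it assumes the checkpoint records the class names of the algorithms it was trained with, and it reports any required algorithm missing from the current run. `ModelSurgeryAlgo`, `ScheduleOnlyAlgo`, `REGISTRY`, and `missing_required` are illustrative names; Composer's real loading logic may differ.

```python
class ModelSurgeryAlgo:
    # Hypothetical algorithm that changes the model's modules, so a checkpoint
    # saved with it cannot be loaded correctly without it.
    @staticmethod
    def required_on_load() -> bool:
        return True


class ScheduleOnlyAlgo:
    # Hypothetical algorithm that only affects training dynamics.
    @staticmethod
    def required_on_load() -> bool:
        return False


REGISTRY = {cls.__name__: cls for cls in (ModelSurgeryAlgo, ScheduleOnlyAlgo)}


def missing_required(checkpoint_algorithms, current_algorithms):
    # Names recorded in the checkpoint whose class reports it is required on load
    # but which are absent from the algorithms configured for this run.
    current = {type(a).__name__ for a in current_algorithms}
    return sorted(
        name for name in checkpoint_algorithms
        if REGISTRY[name].required_on_load() and name not in current
    )


saved = ["ModelSurgeryAlgo", "ScheduleOnlyAlgo"]
print(missing_required(saved, [ScheduleOnlyAlgo()]))  # ['ModelSurgeryAlgo']
print(missing_required(saved, [ModelSurgeryAlgo()]))  # []
```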