composer.core.algorithm
Base class for algorithms that improve a model's quality or efficiency.
Classes
Algorithm: Base class for algorithms.
- class composer.core.algorithm.Algorithm
Bases: composer.core.serializable.Serializable, abc.ABC
Base class for algorithms.
Algorithms are pieces of code that run at specific events (see Event) in the training loop. Algorithms modify the trainer's State, generally with the effect of improving the model's quality or increasing the efficiency and throughput of the training loop.

Algorithms must implement the following two methods: match(), which determines whether the algorithm should run given the current Event and State, and apply(), which makes an in-place change to the State.
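The match/apply contract can be illustrated with a minimal, self-contained sketch. Event, State, Algorithm, RecordEvent, and run_event below are hypothetical stand-ins written for illustration only; they are not Composer's real classes, and run_event is a simplified loop rather than Composer's actual engine.

```python
import abc
import enum
from dataclasses import dataclass, field


class Event(enum.Enum):
    """Hypothetical stand-in for composer.core.Event (two members only)."""
    BEFORE_LOSS = "before_loss"
    AFTER_LOSS = "after_loss"


@dataclass
class State:
    """Hypothetical stand-in for composer.core.State."""
    log: list = field(default_factory=list)


class Algorithm(abc.ABC):
    """Hypothetical stand-in mirroring the two documented abstract methods."""

    @abc.abstractmethod
    def match(self, event, state):
        """Return True if apply() should run for this event and state."""

    @abc.abstractmethod
    def apply(self, event, state, logger):
        """Make an in-place change to the state."""


class RecordEvent(Algorithm):
    """Hypothetical algorithm that records every BEFORE_LOSS event."""

    def match(self, event, state):
        return event == Event.BEFORE_LOSS

    def apply(self, event, state, logger):
        state.log.append(event)  # in-place change to the state


def run_event(algorithms, event, state, logger=None):
    """Simplified dispatch: apply() runs only where match() returns True."""
    for algorithm in algorithms:
        if algorithm.match(event, state):
            algorithm.apply(event, state, logger)


state = State()
for event in (Event.BEFORE_LOSS, Event.AFTER_LOSS):
    run_event([RecordEvent()], event, state)
# state.log now holds only Event.BEFORE_LOSS
```

Because match() gates apply(), RecordEvent never touches the state on AFTER_LOSS.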
- abstract apply(event, state, logger)
Applies the algorithm to make an in-place change to the State. Can optionally return an exit code, which is stored in a Trace and made accessible for debugging.
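As a sketch of the optional exit code, the hypothetical DoubleBatch algorithm below returns 1 when it has nothing to modify and returns None otherwise, and the hypothetical run_and_trace helper keeps each returned code in a plain dict. The dict is an assumed simplification: Composer's actual Trace type is not modeled here, and a dict also stands in for State.

```python
import enum


class Event(enum.Enum):
    """Hypothetical stand-in for composer.core.Event."""
    BATCH_START = "batch_start"


class DoubleBatch:
    """Hypothetical algorithm whose apply() may return an exit code."""

    def match(self, event, state):
        return event == Event.BATCH_START

    def apply(self, event, state, logger):
        if not state["batch"]:
            return 1  # nonzero code: nothing to modify
        # In-place change: replace the batch stored in the state dict.
        state["batch"] = [x * 2 for x in state["batch"]]
        return None  # returning no code is also allowed


def run_and_trace(algorithms, event, state, logger=None):
    """Run matching algorithms and keep their exit codes for debugging."""
    trace = {}  # hypothetical trace: "<AlgorithmName>/<EVENT>" -> exit code
    for algorithm in algorithms:
        if algorithm.match(event, state):
            exit_code = algorithm.apply(event, state, logger)
            trace[f"{type(algorithm).__name__}/{event.name}"] = exit_code
    return trace


full = {"batch": [1, 2, 3]}
empty = {"batch": []}
print(run_and_trace([DoubleBatch()], Event.BATCH_START, full))   # {'DoubleBatch/BATCH_START': None}
print(run_and_trace([DoubleBatch()], Event.BATCH_START, empty))  # {'DoubleBatch/BATCH_START': 1}
```

A nonzero code here only signals "nothing happened"; what a code means is up to each algorithm in this sketch.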
- property backwards_create_graph
Return True to indicate that this algorithm requires a second derivative to be computed. Defaults to False.
If it returns True, create_graph=True will be passed to torch.Tensor.backward(), which causes the graph of the gradient to be constructed as well. This allows second-order derivatives to be computed.
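A trainer has to combine this flag across all of its algorithms before calling backward. The sketch below assumes the flags are OR-ed together, so a single algorithm returning True is enough. Algorithm, GradientPenalty, NoOp, and backward_kwargs are hypothetical names, and torch is deliberately not imported so the sketch runs without it.

```python
class Algorithm:
    """Hypothetical stand-in base class exposing the documented default."""

    @property
    def backwards_create_graph(self):
        return False


class GradientPenalty(Algorithm):
    """Hypothetical algorithm that differentiates through a gradient."""

    @property
    def backwards_create_graph(self):
        return True


class NoOp(Algorithm):
    """Hypothetical algorithm that keeps the default (False)."""


def backward_kwargs(algorithms):
    """Build keyword arguments for loss.backward(); assumes any True wins."""
    return {"create_graph": any(a.backwards_create_graph for a in algorithms)}


print(backward_kwargs([NoOp()]))                     # {'create_graph': False}
print(backward_kwargs([NoOp(), GradientPenalty()]))  # {'create_graph': True}
```

With a real PyTorch loss tensor, the result could be unpacked into the call as loss.backward(**backward_kwargs(algorithms)), since create_graph is a keyword argument of torch.Tensor.backward().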
- property find_unused_parameters
Return True to indicate that the effect of this algorithm may cause some model parameters to be unused. Defaults to False.
For example, it is used to tell torch.nn.parallel.DistributedDataParallel (DDP) that some parameters will be frozen during training, so it should not expect gradients for them. All algorithms that do any kind of parameter freezing should override this property to return True.
Note
DeepSpeed integration with this property returning True is not tested. It may not work as expected.
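A parameter-freezing algorithm overrides find_unused_parameters, and a trainer could fold the flags of all algorithms into the keyword arguments it uses when wrapping the model in DDP. FreezeEmbeddings, KeepsAllParams, and ddp_kwargs are hypothetical names, and OR-ing the flags is an assumption of this sketch rather than a documented Composer behavior.

```python
class Algorithm:
    """Hypothetical stand-in base class exposing the documented default."""

    @property
    def find_unused_parameters(self):
        return False


class FreezeEmbeddings(Algorithm):
    """Hypothetical algorithm that freezes some parameters during training."""

    @property
    def find_unused_parameters(self):
        # Frozen parameters receive no gradients, so DDP must not wait for them.
        return True


class KeepsAllParams(Algorithm):
    """Hypothetical algorithm that trains every parameter (default False)."""


def ddp_kwargs(algorithms):
    """Keyword arguments for the DDP wrapper; assumes any True wins."""
    return {"find_unused_parameters": any(a.find_unused_parameters for a in algorithms)}


print(ddp_kwargs([KeepsAllParams()]))                      # {'find_unused_parameters': False}
print(ddp_kwargs([KeepsAllParams(), FreezeEmbeddings()]))  # {'find_unused_parameters': True}
```

The resulting dict could then be passed as DistributedDataParallel(model, **ddp_kwargs(algorithms)); find_unused_parameters is an existing keyword argument of the DDP constructor.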
- abstract match(event, state)
Determines whether this algorithm should run given the current Event and State.
Examples:
To only run on a specific event (e.g., on BEFORE_LOSS), override match as shown below:
>>> class MyAlgorithm:
...     def match(self, event, state):
...         return event == Event.BEFORE_LOSS
>>> MyAlgorithm().match(Event.BEFORE_LOSS, state)
True
To run based on some value of a State attribute, override match as shown below:
>>> class MyAlgorithm:
...     def match(self, event, state):
...         return state.timer.epoch > 30
>>> MyAlgorithm().match(Event.BEFORE_LOSS, state)
False
See State for accessible attributes.
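The doctests above assume that Event and a state object already exist. A self-contained version might look like the sketch below, which uses hypothetical stand-ins for Event, State, and state.timer (with a plain integer epoch rather than Composer's own time type) and illustrative class names, including one that combines an event check with a state check.

```python
import enum
from dataclasses import dataclass


class Event(enum.Enum):
    """Hypothetical stand-in for composer.core.Event."""
    BEFORE_LOSS = "before_loss"
    AFTER_LOSS = "after_loss"


@dataclass
class Timer:
    """Hypothetical stand-in for state.timer; epoch is a plain int here."""
    epoch: int = 0


@dataclass
class State:
    """Hypothetical stand-in for composer.core.State."""
    timer: Timer


class RunBeforeLoss:
    def match(self, event, state):
        return event == Event.BEFORE_LOSS  # event-based condition


class RunAfterEpoch30:
    def match(self, event, state):
        return state.timer.epoch > 30  # state-based condition


class RunBeforeLossAfterEpoch30:
    def match(self, event, state):
        # Both conditions must hold.
        return event == Event.BEFORE_LOSS and state.timer.epoch > 30


state = State(timer=Timer(epoch=10))
print(RunBeforeLoss().match(Event.BEFORE_LOSS, state))    # True
print(RunAfterEpoch30().match(Event.BEFORE_LOSS, state))  # False
```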