- class composer.optim.DecoupledAdamW(params, lr=0.001, betas=(0.9, 0.95), eps=1e-08, weight_decay=1e-05, amsgrad=False)
Adam optimizer with the weight decay term decoupled from the learning rate.
NOTE: Since weight_decay is no longer scaled by lr, you will likely want to use much smaller values for weight_decay than you would if using torch.optim.Adam or torch.optim.AdamW. In this optimizer, the value weight_decay translates exactly to: ‘On every optimizer update, every weight element will be multiplied by (1.0 - weight_decay_t)’. The term weight_decay_t will follow the same schedule as lr_t but crucially will not be scaled by lr.
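The note above can be illustrated with a small pure-Python sketch. It assumes that "follows the same schedule as lr_t" means weight_decay_t = weight_decay * (lr_t / initial_lr); the helper name `decoupled_decay_factor` and the three-step schedule are hypothetical and are not part of the Composer API.

```python
def decoupled_decay_factor(weight_decay, lr_t, initial_lr):
    """Return the multiplier applied to every weight element on one update.

    Assumption: weight_decay_t = weight_decay * (lr_t / initial_lr), so the
    decay tracks the shape of the lr schedule but not the magnitude of lr.
    """
    schedule_scale = (lr_t / initial_lr) if initial_lr else 1.0
    return 1.0 - weight_decay * schedule_scale


initial_lr = 1e-3
weight_decay = 1e-5
lr_schedule = [1e-3, 5e-4, 0.0]  # hypothetical schedule decaying to zero

weight = 2.0
for lr_t in lr_schedule:
    # Decay only; the Adam gradient update is omitted to isolate the effect.
    weight *= decoupled_decay_factor(weight_decay, lr_t, initial_lr)

print(weight)
```

When the schedule reaches lr_t = 0, the factor is exactly 1.0 and the weights stop decaying, which is the behavior the note describes.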
Argument defaults are similar to torch.optim.AdamW, but we make two changes:
* The default for weight_decay is changed from 1e-2 to 1e-5 because in DecoupledAdamW, the weight decay is decoupled and no longer scaled by lr=1e-3.
* The default for betas is changed from (0.9, 0.999) to (0.9, 0.95) to reflect community best practices for the beta2 hyperparameter.
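The arithmetic behind the new weight_decay default can be checked directly. The sketch below compares the per-step shrink applied to a weight under torch.optim.AdamW defaults, where the decay is multiplied by lr, with the shrink under the DecoupledAdamW default at the start of training. The function names are hypothetical helpers for illustration only.

```python
import math


def coupled_shrink(lr, weight_decay):
    """Per-step fraction removed from a weight when decay is scaled by lr."""
    return lr * weight_decay


def decoupled_shrink(weight_decay):
    """Per-step fraction removed when decay is not scaled by lr
    (evaluated at the start of the schedule, where lr_t equals the initial lr)."""
    return weight_decay


# torch.optim.AdamW defaults: lr=1e-3, weight_decay=1e-2
torch_default = coupled_shrink(lr=1e-3, weight_decay=1e-2)
# DecoupledAdamW default: weight_decay=1e-5
composer_default = decoupled_shrink(weight_decay=1e-5)

# Both defaults remove the same fraction of each weight per step.
same_at_defaults = math.isclose(torch_default, composer_default)

# Doubling lr doubles the coupled shrink but leaves the decoupled one unchanged.
coupled_after_doubling_lr = coupled_shrink(lr=2e-3, weight_decay=1e-2)

print(same_at_defaults, coupled_after_doubling_lr, composer_default)
```

This is why a weight_decay value tuned for torch.optim.AdamW is usually far too large here: it should be reduced by roughly a factor of the learning rate.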
Why use this optimizer? The standard AdamW optimizer explicitly couples the weight decay term with the learning rate. This ties the optimal value of weight_decay to the choice of lr and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization.
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional) – Learning rate. Default: 0.001.
betas (tuple, optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.95).
eps (float, optional) – Term added to the denominator to improve numerical stability. Default: 1e-08.
weight_decay (float, optional) – Decoupled weight decay factor. Default: 1e-05.
amsgrad (bool, optional) – Enables the amsgrad variant of Adam. Default: False.
- static adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps)
Functional API that performs AdamW algorithm computation with decoupled weight decay.
params (list) – List of parameters to update.
grads (list) – List of parameter gradients.
exp_avgs (list) – List of average gradients.
exp_avg_sqs (list) – List of average squared gradients.
max_exp_avg_sqs (list) – List of max average squared gradients for amsgrad updates.
state_steps (list) – List of steps taken for all parameters.
amsgrad (bool) – Enables amsgrad variant of Adam.
beta1 (float) – Coefficient for computing the moving average of gradient values.
beta2 (float) – Coefficient for computing the moving average of squared gradient values.
lr (float) – Learning rate.
initial_lr (float) – Initial learning rate.
weight_decay (float) – Factor for decoupled weight decay.
eps (float) – Term added to the denominator to improve numerical stability.
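To make the role of each argument concrete, here is a minimal scalar sketch of one decoupled AdamW update in plain Python. It is not the library's implementation: the real function operates in place on lists of tensors, and the function name `decoupled_adamw_step_scalar` is hypothetical. The sketch also assumes that the decay is scaled by lr / initial_lr, as suggested by the initial_lr argument and the schedule description in the class notes.

```python
import math


def decoupled_adamw_step_scalar(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, step, *,
                                amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps):
    """One update for a single scalar parameter; `step` must be >= 1."""
    # Decoupled weight decay (assumption: scaled by the schedule ratio, not by lr).
    if weight_decay != 0:
        decay_factor = (lr / initial_lr) if initial_lr else 1.0
        param *= 1.0 - decay_factor * weight_decay

    # Moving averages of the gradient and of the squared gradient.
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad

    # Bias corrections for the zero-initialized averages.
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step

    # amsgrad keeps the running maximum of the second-moment estimate.
    if amsgrad:
        max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)
        second_moment = max_exp_avg_sq
    else:
        second_moment = exp_avg_sq

    denom = math.sqrt(second_moment) / math.sqrt(bias_correction2) + eps
    param -= (lr / bias_correction1) * exp_avg / denom
    return param, exp_avg, exp_avg_sq, max_exp_avg_sq


new_param, new_exp_avg, new_exp_avg_sq, _ = decoupled_adamw_step_scalar(
    1.0, 0.5, 0.0, 0.0, 0.0, 1,
    amsgrad=False, beta1=0.9, beta2=0.95,
    lr=1e-3, initial_lr=1e-3, weight_decay=1e-5, eps=1e-8,
)
print(new_param, new_exp_avg, new_exp_avg_sq)
```

On the first step the bias-corrected Adam update has magnitude close to lr, so the parameter moves by about 1e-3 from the gradient term plus 1e-5 from the decay.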
Preprocess metrics to reduce across ranks correctly.
Performs a single optimization step.
closure (callable, optional) – A closure that reevaluates the model and returns the loss.