DecoupledAdamW#
- class composer.optim.DecoupledAdamW(params, lr=0.001, betas=(0.9, 0.95), eps=1e-08, weight_decay=1e-05, amsgrad=False)[source]#
Adam optimizer with the weight decay term decoupled from the learning rate.
NOTE: Since weight_decay is no longer scaled by lr, you will likely want to use much smaller values for weight_decay than you would if using torch.optim.Adam or torch.optim.AdamW. In this optimizer, the value weight_decay translates exactly to: "On every optimizer update, every weight element will be multiplied by (1.0 - weight_decay_t)". The term weight_decay_t will follow the same schedule as lr_t but crucially will not be scaled by lr.
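To illustrate that note, here is a minimal sketch of the decay applied on each step; the helper name apply_decoupled_decay and its lr_t / initial_lr arguments are assumptions for this example, not part of the library API:

```python
import torch

def apply_decoupled_decay(param: torch.Tensor, weight_decay: float,
                          lr_t: float, initial_lr: float) -> None:
    # Sketch: weight_decay_t follows the relative lr schedule (lr_t / initial_lr)
    # but is NOT multiplied by lr itself.
    decay_factor = lr_t / initial_lr if initial_lr else 1.0
    weight_decay_t = decay_factor * weight_decay
    param.mul_(1.0 - weight_decay_t)  # every weight element multiplied by (1 - weight_decay_t)
```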
Argument defaults are similar to torch.optim.AdamW, but we make two changes:
- The default for weight_decay is changed from 1e-2 to 1e-5, because in DecoupledAdamW the weight decay is decoupled and no longer scaled by lr=1e-3.
- The default for betas is changed from (0.9, 0.999) to (0.9, 0.95) to reflect community best practices for the beta2 hyperparameter.

Why use this optimizer? The standard AdamW optimizer explicitly couples the weight decay term with the learning rate. This ties the optimal value of weight_decay to lr and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization.
- Parameters
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional) – Learning rate. Default: 1e-3.
betas (tuple, optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.95).
eps (float, optional) – Term added to the denominator to improve numerical stability. Default: 1e-8.
weight_decay (float, optional) – Decoupled weight decay factor. Default: 1e-5.
amsgrad (bool, optional) – Enables the amsgrad variant of Adam. Default: False.
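A short usage sketch: the model, batch, and training-step snippet below are placeholders for illustration; only the DecoupledAdamW class and its arguments come from this page.

```python
import torch
from composer.optim import DecoupledAdamW

# Placeholder model and batch, just to exercise the optimizer.
model = torch.nn.Linear(16, 4)

optimizer = DecoupledAdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=1e-5,  # note: much smaller than the 1e-2 default of torch.optim.AdamW
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```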
- static adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps)[source]#
Functional API that performs AdamW algorithm computation with decoupled weight decay.
- Parameters
params (list) – List of parameters to update.
grads (list) – List of parameter gradients.
exp_avgs (list) – List of average gradients.
exp_avg_sqs (list) – List of average squared gradients.
max_exp_avg_sqs (list) – List of max average squared gradients for amsgrad updates.
state_steps (list) – List of steps taken for all parameters.
amsgrad (bool) – Enables the amsgrad variant of Adam.
beta1 (float) – Coefficient for computing the moving average of gradient values.
beta2 (float) – Coefficient for computing the moving average of squared gradient values.
lr (float) – Learning rate.
initial_lr (float) – Initial learning rate.
weight_decay (float) – Factor for decoupled weight decay.
eps (float) – Term added to the denominator to improve numerical stability.
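For reference, a simplified single-parameter sketch of the kind of update this functional API performs: a bias-corrected Adam step plus the decoupled decay illustrated earlier. The function name and per-parameter framing are assumptions for illustration, the amsgrad branch is omitted, and this is not the library implementation.

```python
import math

def decoupled_adamw_step(param, grad, exp_avg, exp_avg_sq, step, *,
                         beta1, beta2, lr, initial_lr, weight_decay, eps):
    # Decoupled weight decay: scaled by the relative lr schedule (lr / initial_lr),
    # not by the magnitude of lr itself (sketch; amsgrad variant omitted).
    decay_factor = lr / initial_lr if initial_lr else 1.0
    param.mul_(1.0 - decay_factor * weight_decay)

    # Standard Adam moment updates with bias correction.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))
```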