composer.optim.decoupled_weight_decay#
Optimizers with weight decay decoupled from the learning rate.
These optimizers are based on Decoupled Weight Decay Regularization, which proposes this decoupling. In general, these optimizers are recommended over their native PyTorch equivalents.
Classes
DecoupledAdamW | Adam optimizer with the weight decay term decoupled from the learning rate.
DecoupledSGDW | SGD optimizer with the weight decay term decoupled from the learning rate.
- class composer.optim.decoupled_weight_decay.DecoupledAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)[source]#
Bases:
torch.optim.adamw.AdamW
Adam optimizer with the weight decay term decoupled from the learning rate.
Argument defaults are copied from torch.optim.AdamW.

The standard AdamW optimizer explicitly couples the weight decay term with the learning rate. This ties the optimal value of weight_decay to lr and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization. A short usage sketch follows the parameter list below.

- Parameters
params (list) – List of parameters to update.
lr (float, optional) – Learning rate. Default: 1e-3.
betas (tuple, optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.999).
eps (float, optional) – Term added to the denominator to improve numerical stability. Default: 1e-8.
weight_decay (float, optional) – Decoupled weight decay factor. Default: 1e-2.
amsgrad (bool, optional) – Enables the AMSGrad variant of Adam. Default: False.
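A minimal usage sketch (not from the original docs, and assuming only the constructor signature shown above): DecoupledAdamW is used exactly like torch.optim.AdamW in a plain training loop.

```python
import torch
import torch.nn.functional as F
from composer.optim.decoupled_weight_decay import DecoupledAdamW

model = torch.nn.Linear(10, 2)
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                           eps=1e-8, weight_decay=1e-2)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()       # weight decay is applied to the weights directly, not folded into the lr-scaled update
optimizer.zero_grad()
```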
- static adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps)[source]#
Functional API that performs AdamW algorithm computation with decoupled weight decay.
- Parameters
params (List[Tensor]) – List of parameters to update.
grads (List[Tensor]) – List of parameter gradients.
exp_avgs (List[Tensor]) – List of average gradients.
exp_avg_sqs (List[Tensor]) – List of average squared gradients.
max_exp_avg_sqs (List[Tensor]) – List of max average squared gradients for AMSGrad updates.
state_steps (Iterable[int]) – List of steps taken for all parameters.
amsgrad (bool) – Enables the AMSGrad variant of Adam.
beta1 (float) – Coefficient for computing the moving average of gradient values.
beta2 (float) – Coefficient for computing the moving average of squared gradient values.
lr (float) – Learning rate.
initial_lr (float) – Initial learning rate.
weight_decay (float) – Factor for decoupled weight decay.
eps (float) – Term added to the denominator to improve numerical stability.
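The initial_lr argument is what keeps the decay decoupled from the learning-rate schedule: the decay is rescaled by the ratio of the current to the initial learning rate rather than by lr itself. A rough, illustrative sketch of that decay step (an assumption based on this signature, not the library's exact code):

```python
import torch

def decoupled_decay(param: torch.Tensor, lr: float, initial_lr: float, weight_decay: float) -> None:
    # Coupled AdamW shrinks weights by lr * weight_decay, so the effective decay moves
    # with the absolute learning rate. The decoupled form sketched here scales the decay
    # by lr / initial_lr instead: it follows the schedule's shape but not lr's magnitude.
    # (Real optimizers run this under torch.no_grad(); plain tensors are assumed here.)
    decay_factor = lr / initial_lr if initial_lr else 1.0
    param.mul_(1.0 - decay_factor * weight_decay)
    # ...the ordinary Adam moment update on the gradient would then follow.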
- class composer.optim.decoupled_weight_decay.DecoupledSGDW(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)[source]#
Bases:
torch.optim.sgd.SGD
SGD optimizer with the weight decay term decoupled from the learning rate.
Argument defaults are copied from
torch.optim.SGD
.The standard SGD optimizer couples the weight decay term with the gradient calculation. This ties the optimal value of
weight_decay
tolr
and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization.- Parameters
params (list) – List of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate (required).
momentum (float, optional) – Momentum factor. Default: 0.
dampening (float, optional) – Dampening factor applied to the momentum. Default: 0.
weight_decay (float, optional) – Decoupled weight decay factor. Default: 0.
nesterov (bool, optional) – Enables Nesterov momentum updates. Default: False.
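A minimal usage sketch (not from the original docs, and assuming only the constructor signature shown above): DecoupledSGDW is a drop-in replacement for torch.optim.SGD.

```python
import torch
import torch.nn.functional as F
from composer.optim.decoupled_weight_decay import DecoupledSGDW

model = torch.nn.Linear(10, 2)
optimizer = DecoupledSGDW(model.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=1e-4, nesterov=True)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()       # decay is applied to the weights directly, outside the gradient step
optimizer.zero_grad()
```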
- static sgdw(params, d_p_list, momentum_buffer_list, *, weight_decay, momentum, lr, initial_lr, dampening, nesterov)[source]#
Functional API that performs SGDW algorithm computation.
- Parameters
params (list) – List of parameters to update.
d_p_list (list) – List of parameter gradients.
momentum_buffer_list (list) – List of momentum buffers.
weight_decay (float) – Decoupled weight decay factor.
momentum (float) – Momentum factor.
lr (float) – Learning rate.
initial_lr (float) – Initial learning rate.
dampening (float) – Dampening factor for the momentum update.
nesterov (bool) – Enables Nesterov momentum updates.
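For contrast with torch.optim.SGD, which folds weight_decay * param into the gradient (and therefore into the lr-scaled update), below is a rough per-parameter sketch of the decoupled step. The lr / initial_lr scaling is an assumption based on the initial_lr argument above, not the library's exact code; momentum and dampening are omitted for brevity.

```python
import torch

def coupled_sgd_step(param: torch.Tensor, grad: torch.Tensor, lr: float, weight_decay: float) -> None:
    # Standard SGD: the decay term rides along with the gradient, so it is scaled by lr.
    grad = grad + weight_decay * param
    param.add_(grad, alpha=-lr)

def decoupled_sgdw_step(param: torch.Tensor, grad: torch.Tensor, lr: float,
                        initial_lr: float, weight_decay: float) -> None:
    # SGDW sketch: shrink the weights directly, scaled by the schedule ratio, then take
    # the plain gradient step. (Real optimizers run this under torch.no_grad().)
    decay_factor = lr / initial_lr if initial_lr else 1.0
    param.mul_(1.0 - decay_factor * weight_decay)
    param.add_(grad, alpha=-lr)
```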