class composer.optim.DecoupledAdamW(params, lr=0.001, betas=(0.9, 0.95), eps=1e-08, weight_decay=1e-05, amsgrad=False)

Adam optimizer with the weight decay term decoupled from the learning rate.

NOTE: Because weight_decay is no longer scaled by lr, you will likely want much smaller values for weight_decay than you would use with torch.optim.Adam or torch.optim.AdamW. In this optimizer, the value weight_decay translates exactly to: ‘On every optimizer update, every weight element is multiplied by (1.0 - weight_decay_t)’. The term weight_decay_t follows the same schedule as lr_t but, crucially, is not scaled by lr.
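The rule in the note can be checked numerically. The plain-Python sketch below applies only the decay part of the update (the gradient step is left out so the decay is visible on its own). It assumes the schedule is applied as weight_decay_t = weight_decay * lr_t / initial_lr, so that a constant learning rate gives weight_decay_t = weight_decay; that formula is an assumption for illustration, and `decay_only` is a hypothetical helper, not part of Composer.

```python
def decay_only(weight, weight_decay, lrs, initial_lr):
    """Apply only the decoupled weight-decay part of each update.

    Hypothetical helper. Assumes weight_decay_t = weight_decay * lr_t / initial_lr,
    i.e. the decay follows the lr schedule but is not multiplied by lr itself.
    """
    for lr_t in lrs:
        weight_decay_t = weight_decay * lr_t / initial_lr
        weight *= 1.0 - weight_decay_t  # every update multiplies by (1 - weight_decay_t)
    return weight


# Constant schedule: 1000 steps at lr = 1e-3 with weight_decay = 1e-5.
w_const = decay_only(1.0, 1e-5, [1e-3] * 1000, initial_lr=1e-3)

# Same number of steps, but the schedule halves lr for the second half of training,
# so (under the assumed rule) the decay is halved too.
w_step = decay_only(1.0, 1e-5, [1e-3] * 500 + [5e-4] * 500, initial_lr=1e-3)

print(w_const, w_step)
```

With a constant schedule the weight after 1000 updates is (1 - 1e-5)^1000, about 0.990; halving the learning rate halfway through leaves the weight less decayed.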

Argument defaults are similar to torch.optim.AdamW, with two changes:

  • The default for weight_decay is changed from 1e-2 to 1e-5. In DecoupledAdamW the weight decay is decoupled and no longer multiplied by lr, and at the default lr=1e-3 AdamW's default of 1e-2 amounts to a per-step decay of 1e-3 × 1e-2 = 1e-5.

  • The default for betas is changed from (0.9, 0.999) to (0.9, 0.95) to reflect community best practices for the beta2 hyperparameter.

Why use this optimizer? The standard AdamW optimizer explicitly couples the weight decay term with the learning rate: each step shrinks the weights by lr × weight_decay. This ties the optimal value of weight_decay to lr, so changing one means retuning the other, and it can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization (Loshchilov & Hutter).
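To make the coupling concrete, the plain-Python sketch below compares the per-step multiplicative shrink under both rules while sweeping lr, holding AdamW's weight_decay at 1e-2 and DecoupledAdamW's at 1e-5. It ignores the gradient term and assumes a constant schedule (weight_decay_t = weight_decay); the helper names are hypothetical. Under the coupled rule the decay per step changes by a factor of 100 across the sweep, so a weight_decay tuned at one lr is mis-tuned at another; under the decoupled rule it stays fixed.

```python
def coupled_shrink(lr, weight_decay):
    # torch.optim.AdamW-style decay: each step multiplies weights by (1 - lr * weight_decay).
    return 1.0 - lr * weight_decay


def decoupled_shrink(weight_decay_t):
    # Decoupled decay: multiply by (1 - weight_decay_t); lr does not appear.
    return 1.0 - weight_decay_t


for lr in (1e-4, 1e-3, 1e-2):
    print(
        f"lr={lr:g}: coupled={coupled_shrink(lr, 1e-2):.7f}, "
        f"decoupled={decoupled_shrink(1e-5):.7f}"
    )
```

Only at lr = 1e-3 do the two rules coincide; at any other learning rate the coupled decay drifts while the decoupled one does not.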

Parameters:

  • params (iterable) – Iterable of parameters to optimize, or dicts defining parameter groups.

  • lr (float, optional) – Learning rate. Default: 1e-3.

  • betas (tuple, optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.95).

  • eps (float, optional) – Term added to the denominator to improve numerical stability. Default: 1e-8.

  • weight_decay (float, optional) – Decoupled weight decay factor. Default: 1e-5.

  • amsgrad (bool, optional) – Enables the amsgrad variant of Adam. Default: False.
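The betas control how quickly the running averages forget old gradients. A common rule of thumb is that an exponential moving average with coefficient β averages over roughly 1/(1 - β) recent steps, so β2 = 0.95 looks back about 20 steps versus about 1000 for 0.999. The plain-Python sketch below (scalar values, hypothetical helper name) feeds a squared-gradient sequence that jumps from 1.0 to 4.0 and compares the two settings. It illustrates only the moving-average behaviour, not Composer's implementation.

```python
def bias_corrected_ema(values, beta):
    """Exponential moving average with Adam-style bias correction (scalar sketch)."""
    avg = 0.0
    for v in values:
        avg = beta * avg + (1.0 - beta) * v
    # Divide by (1 - beta**t) to undo the bias from initializing avg at zero.
    return avg / (1.0 - beta ** len(values))


# Squared gradients: 100 steps at 1.0, then the gradient doubles (squared value 4.0).
sq_grads = [1.0] * 100 + [4.0] * 20

fast = bias_corrected_ema(sq_grads, beta=0.95)   # DecoupledAdamW default beta2
slow = bias_corrected_ema(sq_grads, beta=0.999)  # torch.optim.AdamW default beta2
print(f"beta2=0.95: {fast:.3f}   beta2=0.999: {slow:.3f}")
```

Twenty steps after the jump, the β2 = 0.95 estimate has moved most of the way toward 4.0, while the β2 = 0.999 estimate still mostly reflects the old regime.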

static adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps)

Functional API that performs the AdamW algorithm computation with decoupled weight decay.

Parameters:

  • params (list) – List of parameters to update.

  • grads (list) – List of parameter gradients.

  • exp_avgs (list) – List of exponential moving averages of the gradients.

  • exp_avg_sqs (list) – List of exponential moving averages of the squared gradients.

  • max_exp_avg_sqs (list) – List of running maxima of the squared-gradient averages, used for amsgrad updates.

  • state_steps (list) – List of step counts, one per parameter.

  • amsgrad (bool) – Enables the amsgrad variant of Adam.

  • beta1 (float) – Coefficient for computing the moving average of gradient values.

  • beta2 (float) – Coefficient for computing the moving average of squared gradient values.

  • lr (float) – Learning rate.

  • initial_lr (float) – Initial learning rate.

  • weight_decay (float) – Factor for decoupled weight decay.

  • eps (float) – Term added to the denominator to improve numerical stability.
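The functional API above can be read as the following plain-Python sketch over scalar floats. It mirrors the argument names of `adamw`, but it is an illustration under stated assumptions, not Composer's source: it assumes the decay factor is lr / initial_lr (so weight_decay_t tracks the learning-rate schedule, as the note describes), that `state_steps` already holds the incremented step count, and standard Adam bias correction. Real parameters are tensors updated in place; here the lists of floats are mutated instead. `adamw_scalar` is a hypothetical name.

```python
import math


def adamw_scalar(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
                 *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps):
    """Scalar sketch of one decoupled-AdamW step (hypothetical, lists mutated in place).

    Assumptions: decay is scheduled as weight_decay * lr / initial_lr, and
    state_steps[i] is the step count after incrementing (1 on the first step).
    """
    for i in range(len(params)):
        step = state_steps[i]
        # Decoupled weight decay: shrink the weight directly; lr only enters via the schedule ratio.
        if weight_decay != 0:
            decay_factor = lr / initial_lr if initial_lr else 1.0
            params[i] *= 1.0 - decay_factor * weight_decay
        # Update the biased first and second moment estimates.
        exp_avgs[i] = beta1 * exp_avgs[i] + (1.0 - beta1) * grads[i]
        exp_avg_sqs[i] = beta2 * exp_avg_sqs[i] + (1.0 - beta2) * grads[i] ** 2
        second = exp_avg_sqs[i]
        if amsgrad:
            # AMSGrad normalizes by the running maximum of the second moment.
            max_exp_avg_sqs[i] = max(max_exp_avg_sqs[i], exp_avg_sqs[i])
            second = max_exp_avg_sqs[i]
        # Bias corrections compensate for initializing both averages at zero.
        bias_correction1 = 1.0 - beta1 ** step
        bias_correction2 = 1.0 - beta2 ** step
        denom = math.sqrt(second) / math.sqrt(bias_correction2) + eps
        params[i] -= (lr / bias_correction1) * exp_avgs[i] / denom
```

On the first step with a gradient of 1.0 and no weight decay, the bias-corrected ratio is 1, so the parameter moves by almost exactly lr; with a zero gradient, only the decay term changes the weight.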


Preprocess metrics to reduce across ranks correctly.


Performs a single optimization step.


Parameters:

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
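The closure argument follows the general PyTorch optimizer convention: when a callable is passed, step calls it to recompute the loss (and gradients) and returns that loss. The toy below shows only that contract in plain Python. `ToyOptimizer` is hypothetical and applies plain gradient descent rather than the decoupled AdamW update.

```python
class ToyOptimizer:
    """Hypothetical optimizer following the step(closure=None) convention."""

    def __init__(self, params, lr):
        self.params = params  # name -> float weight
        self.grads = {}       # name -> float gradient, filled in by the closure
        self.lr = lr

    def step(self, closure=None):
        loss = None
        if closure is not None:
            # Re-evaluate the model: the closure recomputes the loss and gradients.
            loss = closure()
        # Plain gradient descent here; DecoupledAdamW would apply its own update.
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad
        return loss


params = {"w": 0.0}
opt = ToyOptimizer(params, lr=0.1)


def closure():
    # Loss (w - 3)^2 and its gradient 2 * (w - 3) at the current weight.
    w = params["w"]
    opt.grads["w"] = 2.0 * (w - 3.0)
    return (w - 3.0) ** 2


for _ in range(50):
    loss = opt.step(closure)
print(params["w"], loss)
```

The returned loss is the value computed by the closure before that step's update, which is why step returns None when no closure is given.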