# Source code for composer.models.initializers
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Module Initializers."""
from typing import Callable
import torch
from torch import nn as nn
from composer.utils import StringEnum
class Initializer(StringEnum):
    """Sets the initialization scheme for different layers of a PyTorch model.

    Each member names a weight-initialization strategy; :meth:`get_initializer`
    maps the member to a callable suitable for :meth:`torch.nn.Module.apply`.
    """

    KAIMING_NORMAL = 'kaiming_normal'
    KAIMING_UNIFORM = 'kaiming_uniform'
    BN_UNIFORM = 'bn_uniform'
    BN_ONES = 'bn_ones'
    XAVIER_UNIFORM = 'xavier_uniform'
    XAVIER_NORMAL = 'xavier_normal'
    LINEAR_LOG_CONSTANT_BIAS = 'linear_log_constant_bias'

    def get_initializer(self) -> Callable[[torch.nn.Module], None]:
        """Get the initializer function.

        The returned callable is meant to be passed to
        :meth:`torch.nn.Module.apply`; it mutates matching modules in place
        and is a no-op on any other module type.

        Returns:
            (torch.nn.Module) -> None: The initializer function.

        Raises:
            ValueError: If the enum value has no registered initializer.
        """

        def kaiming_normal(w: nn.Module):
            # He/Kaiming normal init for linear and conv weights.
            if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
                torch.nn.init.kaiming_normal_(w.weight)

        def kaiming_uniform(w: nn.Module):
            # He/Kaiming uniform init for linear and conv weights.
            if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
                torch.nn.init.kaiming_uniform_(w.weight)

        def xavier_uniform(w: nn.Module):
            # Xavier/Glorot uniform init for linear and conv weights.
            if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
                torch.nn.init.xavier_uniform_(w.weight)

        def xavier_normal(w: nn.Module):
            # Xavier/Glorot normal init for linear and conv weights.
            if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)):
                torch.nn.init.xavier_normal_(w.weight)

        def bn_ones(w: nn.Module):
            # Scale = 1, shift = 0 for batch norm. Guard against
            # BatchNorm2d(affine=False), whose weight/bias are None.
            if isinstance(w, torch.nn.BatchNorm2d) and w.weight is not None:
                w.weight.data = torch.ones_like(w.weight.data)
                w.bias.data = torch.zeros_like(w.bias.data)

        def bn_uniform(w: nn.Module):
            # Scale ~ U(0, 1), shift = 0 for batch norm. Guard against
            # BatchNorm2d(affine=False), whose weight/bias are None.
            if isinstance(w, torch.nn.BatchNorm2d) and w.weight is not None:
                w.weight.data = torch.rand(w.weight.data.shape)
                w.bias.data = torch.zeros_like(w.bias.data)

        def linear_log_constant_bias(w: nn.Module):
            # Initialize linear biases to -log(fan_out). Guard against
            # Linear(bias=False), whose bias attribute is None.
            if isinstance(w, torch.nn.Linear) and w.bias is not None:
                w.bias.data = torch.ones(w.bias.shape) * -torch.log(torch.tensor(w.bias.shape[0]))

        # Dispatch table keyed by the enum's string value.
        initializer_dict = {
            'kaiming_normal': kaiming_normal,
            'kaiming_uniform': kaiming_uniform,
            'bn_uniform': bn_uniform,
            'bn_ones': bn_ones,
            'xavier_uniform': xavier_uniform,
            'xavier_normal': xavier_normal,
            'linear_log_constant_bias': linear_log_constant_bias,
        }
        if self.value not in initializer_dict:
            raise ValueError(f"Initializer '{self.value}' not found.")
        return initializer_dict[self.value]