# Source code for composer.models.initializers
from typing import Callable
import torch
from torch import nn as nn
from composer.utils import StringEnum
class Initializer(StringEnum):
    """Sets the initialization scheme for different layers of a PyTorch model.

    Attributes:
        KAIMING_NORMAL: Kaiming-normal init for ``Linear``/``Conv2d`` weights.
        KAIMING_UNIFORM: Kaiming-uniform init for ``Linear``/``Conv2d`` weights.
        BN_UNIFORM: U(0, 1) weights and zero bias for ``BatchNorm2d`` layers.
        BN_ONES: All-ones weights and zero bias for ``BatchNorm2d`` layers.
        XAVIER_UNIFORM: Xavier-uniform init for ``Linear``/``Conv2d`` weights.
    """

    KAIMING_NORMAL = "kaiming_normal"
    KAIMING_UNIFORM = "kaiming_uniform"
    BN_UNIFORM = "bn_uniform"
    BN_ONES = "bn_ones"
    XAVIER_UNIFORM = "xavier_uniform"

    def get_initializer(self) -> Callable[[torch.nn.Module], None]:
        """Return a callable that applies this scheme to a module in place.

        The returned callable is intended for :meth:`torch.nn.Module.apply`;
        it is a no-op on module types the scheme does not target.

        Returns:
            Callable[[torch.nn.Module], None]: the per-module initializer.

        Raises:
            ValueError: if the enum value has no registered initializer
                (defensive; cannot happen for the members defined above).
        """

        def kaiming_normal(w: nn.Module) -> None:
            if isinstance(w, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(w.weight)

        def kaiming_uniform(w: nn.Module) -> None:
            if isinstance(w, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_uniform_(w.weight)

        def xavier_uniform(w: nn.Module) -> None:
            if isinstance(w, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(w.weight)

        def bn_ones(w: nn.Module) -> None:
            if isinstance(w, nn.BatchNorm2d):
                # In-place init preserves the parameter's device and dtype.
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

        def bn_uniform(w: nn.Module) -> None:
            if isinstance(w, nn.BatchNorm2d):
                # uniform_ defaults to U(0, 1) — same distribution as the
                # previous ``torch.rand`` — but stays on the parameter's
                # device/dtype instead of allocating a fresh CPU tensor.
                nn.init.uniform_(w.weight)
                nn.init.zeros_(w.bias)

        initializer_dict = {
            "kaiming_normal": kaiming_normal,
            "kaiming_uniform": kaiming_uniform,
            "bn_uniform": bn_uniform,
            "bn_ones": bn_ones,
            "xavier_uniform": xavier_uniform,
        }
        if self.value not in initializer_dict:
            raise ValueError(f"Initializer '{self.value}' not found.")
        return initializer_dict[self.value]