# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

# Written by Gihyun Park, Junyeol Lee, and Jiwon Seo

import logging
import warnings
from typing import Dict, Optional, Type

import numpy as np
import torch

from composer.algorithms.warnings import NoEffectWarning
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery

log = logging.getLogger(__name__)

class GyroDropoutLayer(torch.nn.Module):
    """Dropout layer that samples masks from a fixed pool of pre-selected masks.

    On the first training step a pool of ``sigma`` binary feature masks is drawn.
    Every ``iter_num`` training steps, ``tau`` masks are picked (with replacement)
    from that pool; each picked mask is shared by a contiguous slice of the batch.
    In eval mode the input is returned unchanged.

    Args:
        iters_per_epoch (int): Number of training iterations per epoch.
        max_epoch (int): Total number of training epochs.
        p (float): Probability of dropping a feature. Must be in ``[0, 1)``.
        sigma (int): Number of pre-selected masks in the pool. Must be positive.
        tau (int): Number of masks used concurrently within one batch. Must be positive.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1)`` or ``sigma``/``tau`` is not positive.
    """

    def __init__(self, iters_per_epoch: int, max_epoch: int, p: float, sigma: int, tau: int):
        super().__init__()

        # Validate early: p == 1 would otherwise divide by zero in forward, and tau <= 0
        # would break the batch partitioning.
        if not 0 <= p < 1:
            raise ValueError(f'p must be in [0, 1), got {p}.')
        if sigma <= 0 or tau <= 0:
            raise ValueError(f'sigma and tau must be positive, got sigma={sigma}, tau={tau}.')

        self.iters_per_epoch = iters_per_epoch
        self.max_epoch = max_epoch
        self.p = p
        self.sigma = sigma
        self.tau = tau
        # Mask state is populated lazily on the first training step, once the feature dim is known.
        self.preselect_masks = torch.empty(0, 0)
        self.dropout_mask = torch.empty(0, 0)
        self.selected_masks = torch.empty(0, 0)
        self.training_step = 0
        self.iter_num = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gyro dropout in training mode; return ``x`` unchanged in eval mode.

        Assumes ``x`` is 2-D ``(batch, features)`` since masks are built from
        ``x.shape[1]`` and rows from ``x.shape[0]`` — TODO confirm for other ranks.
        """
        if not self.training:
            return x

        if self.training_step == 0:
            # Draw on CPU (same RNG stream as a plain CPU draw), then move to the input's
            # device so masks follow ``x`` on any device instead of a hard-coded 'cuda'.
            self.preselect_masks = (torch.rand(self.sigma, x.shape[1]) > self.p).float().to(x.device)

            # Below simplified from: (iters_per_epoch*max_epoch*batch_size/sigma) / (batch_size/self.tau)
            # Clamped to at least 1 so short schedules (iters_per_epoch * max_epoch < sigma)
            # do not trigger a modulo-by-zero below.
            self.iter_num = max(1, int(self.iters_per_epoch * self.max_epoch / self.sigma) * self.tau)

        # Periodically re-pick which tau masks from the pool are active.
        if self.training_step % self.iter_num == 0:
            pick_idx = np.random.choice(self.sigma, self.tau)
            self.selected_masks = self.preselect_masks[pick_idx]

        batch_size = x.shape[0]
        # Each selected mask covers a contiguous slice of the batch. Round up and trim so a
        # batch size that is not a multiple of tau still gets exactly one mask row per sample.
        rows_per_mask = -(-batch_size // self.tau)
        self.dropout_mask = torch.repeat_interleave(self.selected_masks, rows_per_mask, dim=0)[:batch_size]

        self.training_step += 1

        # Inverted-dropout scaling keeps the expected activation unchanged.
        return x * self.dropout_mask * (1 / (1 - self.p))

def from_Dropout(iters_per_epoch: int, epoch: int, p: float, sigma: int, tau: int, layer: torch.nn.Module,
                 module_index: int):
    """Replacement policy that swaps a ``torch.nn.Dropout`` for a ``GyroDropoutLayer``.

    The first five arguments are bound with ``functools.partial`` by ``apply_gyro_dropout``;
    ``layer`` and ``module_index`` are supplied when the policy is invoked and are not used.
    Note that the new layer's drop probability comes from ``p``, not from ``layer.p``.
    """
    del layer, module_index  # part of the policy call signature; unused here
    replacement = GyroDropoutLayer(
        iters_per_epoch=iters_per_epoch,
        max_epoch=epoch,
        p=p,
        sigma=sigma,
        tau=tau,
    )
    return replacement

def apply_gyro_dropout(
    model: torch.nn.Module,
    iters_per_epoch: int,
    max_epoch: int,
    p: float,
    sigma: int,
    tau: int,
) -> None:
    """Replaces all instances of ``torch.nn.Dropout`` in ``model`` with a ``GyroDropoutLayer``.

    By masking the Dropout layer with a fixed pool of masks, this usually improves accuracy.

    Args:
        model (torch.nn.Module): Model whose ``torch.nn.Dropout`` modules are replaced
            via ``module_surgery.replace_module_classes``.
        iters_per_epoch (int): Number of training iterations per epoch.
        max_epoch (int): Total number of training epochs.
        p (float): Probability of dropping a feature.
        sigma (int): Number of pre-selected masks per replaced layer.
        tau (int): Number of masks used concurrently within one batch.

    Warns:
        NoEffectWarning: If ``model`` contains no ``torch.nn.Dropout`` modules.
    """
    # prepare the replacement policy and perform replacement
    from functools import partial

    policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
        torch.nn.Dropout: partial(from_Dropout, iters_per_epoch, max_epoch, p, sigma, tau),
    }
    replaced_instances = module_surgery.replace_module_classes(module=model, policies=policy)
    if not replaced_instances:
        warnings.warn(
            NoEffectWarning(
                'No instances of `torch.nn.Dropout` were found, and therefore, there were no modules to replace.'))
    log.info(f'Successfully replaced {len(replaced_instances)} of dropout with a Gyro dropout.')
class GyroDropout(Algorithm):
    """Replaces all instances of ``torch.nn.Dropout`` with a ``GyroDropoutLayer``.

    By masking the Dropout layer with a fixed pool of masks, this usually improves accuracy.

    Args:
        p (float, optional): Ratio of features to drop. Default: ``0.5``.
        sigma (int, optional): Number of pre-selected subnetworks (masks) in total.
            Default: ``256``.
        tau (int, optional): Number of subnetworks scheduled concurrently in one
            iteration. Default: ``16``.

    Example:
        .. testcode::

            from composer.algorithms import GyroDropout

            algorithm = GyroDropout(p=0.5, sigma=256, tau=16)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration="100ep",
                algorithms=[algorithm],
                optimizers=[optimizer]
            )
    """

    def __init__(self, p: float = 0.5, sigma: int = 256, tau: int = 16):
        self.p = p
        self.sigma = sigma
        self.tau = tau

        warnings.warn(
            'GyroDropout is not implemented in a way that allows correct resumption from checkpoint, which may lead to incorrect behavior.'
        )

    def __repr__(self) -> str:
        # Include hyperparameters so logged algorithm lists distinguish configurations.
        return f'{self.__class__.__name__}(p={self.p},sigma={self.sigma},tau={self.tau})'

    @staticmethod
    def required_on_load() -> bool:
        # Presumably tells Composer to re-apply this surgery when loading a checkpoint so the
        # module structure matches — confirm against the ``Algorithm`` base-class docs.
        return True

    def match(self, event: Event, state: State) -> bool:
        del state  # unused
        # Surgery runs once, at the start of fit.
        return event == Event.FIT_START

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        del event, logger  # unused
        assert state.dataloader_len is not None
        assert state.max_duration is not None
        # NOTE(review): max_duration.value is passed as the epoch count; this presumes
        # max_duration is expressed in epochs — confirm for batch/sample-based durations.
        apply_gyro_dropout(
            model=state.model,
            iters_per_epoch=state.dataloader_len.value,
            max_epoch=state.max_duration.value,
            p=self.p,
            sigma=self.sigma,
            tau=self.tau,
        )