def apply_factorization(
    model: torch.nn.Module,
    factorize_convs: bool = True,
    factorize_linears: bool = True,
    min_channels: int = 512,
    latent_channels: Union[int, float] = 0.25,
    min_features: int = 512,
    latent_features: Union[int, float] = 0.25,
    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
    """Replaces :class:`torch.nn.Linear` and :class:`torch.nn.Conv2d` modules with
    :class:`.FactorizedLinear` and :class:`.FactorizedConv2d` modules.

    Factorized modules replace one full-rank operation with a sequence of two
    lower-rank operations. When the rank is low enough, this can save
    computation, at the cost of expressive power. See :class:`.Factorize` for
    details.

    Args:
        model (torch.nn.Module): the model to modify in-place.
        factorize_convs (bool, optional): whether to try factorizing
            :class:`torch.nn.Conv2d` modules. Default: ``True``.
        factorize_linears (bool, optional): whether to try factorizing
            :class:`torch.nn.Linear` modules. Default: ``True``.
        min_channels (int, optional): if a :class:`torch.nn.Conv2d` module does
            not have at least this many input and output channels, it will be
            ignored. Modules with few channels are unlikely to be accelerated
            by factorization due to poor hardware utilization.
            Default: ``512``.
        latent_channels (int | float, optional): number of latent channels to
            use in factorized convolutions. Can be specified as either an
            integer > 1 or as a float within ``[0, 1)``. In the latter case,
            the value is interpreted as a fraction of
            ``min(in_channels, out_channels)`` for each
            :class:`torch.nn.Conv2d` module, and is converted to the
            equivalent integer value, with a minimum of 1. Default: ``0.25``.
        min_features (int, optional): if a :class:`torch.nn.Linear` module does
            not have at least this many input and output features, it will be
            ignored. Modules with few features are unlikely to be accelerated
            by factorization due to poor hardware utilization.
            Default: ``512``.
        latent_features (int | float, optional): size of the latent space for
            factorized linear modules. Can be specified as either an integer
            > 1 or as a float within ``[0, 0.5)``. In the latter case, the
            value is interpreted as a fraction of
            ``min(in_features, out_features)`` for each
            :class:`torch.nn.Linear` module, and is converted to the
            equivalent integer value, with a minimum of 1. Default: ``0.25``.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
            existing optimizers bound to ``model.parameters()``. Any optimizer
            that was constructed with ``model.parameters()`` *before* this call
            must be passed here so that it is updated to optimize the new
            factorized parameters. Optimizers constructed *after* this call may
            be omitted, since they will already see the correct parameters.

    Example:
        .. testcode::

            import composer.functional as cf
            from torchvision import models

            model = models.resnet50()
            cf.apply_factorization(model)
    """
    if factorize_convs:
        _factorize_conv2d_modules(
            model,
            min_channels=min_channels,
            latent_channels=latent_channels,
            optimizers=optimizers,
        )
    if factorize_linears:
        _factorize_linear_modules(
            model,
            min_features=min_features,
            latent_features=latent_features,
            optimizers=optimizers,
        )
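
# Illustrative sketch (not part of the library's API): how a float latent value
# such as ``latent_channels=0.25`` maps to an integer rank, per the docstring
# above. The helper name ``_example_resolve_latent`` is hypothetical and exists
# only to demonstrate the conversion rule.
def _example_resolve_latent(latent: Union[int, float], in_size: int, out_size: int) -> int:
    if isinstance(latent, float):
        # Interpreted as a fraction of the smaller dimension, floored at 1.
        return max(1, int(latent * min(in_size, out_size)))
    return latent  # already an explicit integer rank

# e.g. ``_example_resolve_latent(0.25, 512, 512)`` -> 128 latent channels.
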
class Factorize(Algorithm):
    """Decomposes linear operators into pairs of smaller linear operators.

    Specifically, this algorithm replaces :class:`torch.nn.Conv2d` and
    :class:`torch.nn.Linear` modules with :class:`.FactorizedConv2d` and
    :class:`.FactorizedLinear` modules.

    The replacement is only performed if doing so would reduce the number of
    multiply-adds used to compute each module's output. For linear layers and
    pointwise convolutions, this means that the factorization must use an
    intermediate rank of less than half the input and output ranks, since it
    must perform two operations instead of one. For convolutions with kernel
    sizes greater than 1, the threshold for factorization being worthwhile
    varies with kernel size: larger kernels allow larger intermediate ranks.

    See :func:`.factorize_matrix` and :func:`.factorize_conv2d` for more
    information about the factorization process. See
    :class:`.FactorizedConv2d` and :class:`.FactorizedLinear` for more
    information about the factorized modules used to replace the original
    modules.

    Runs on :attr:`.Event.INIT`.

    Args:
        factorize_convs (bool): whether to try factorizing
            :class:`torch.nn.Conv2d` modules. Default: ``True``.
        factorize_linears (bool): whether to try factorizing
            :class:`torch.nn.Linear` modules. Default: ``True``.
        min_channels (int): if a :class:`torch.nn.Conv2d` module does not have
            at least this many input and output channels, it will be ignored.
            Modules with few channels are unlikely to be accelerated by
            factorization due to poor hardware utilization.
            Default: ``256``.
        latent_channels (int | float): number of latent channels to use in
            factorized convolutions. Can be specified as either an integer > 1
            or as a float within ``[0, 1)``. In the latter case, the value is
            interpreted as a fraction of ``min(in_channels, out_channels)``
            for each :class:`torch.nn.Conv2d` module, and is converted to the
            equivalent integer value, with a minimum of 1. Default: ``0.25``.
        min_features (int): if a :class:`torch.nn.Linear` module does not have
            at least this many input and output features, it will be ignored.
            Modules with few features are unlikely to be accelerated by
            factorization due to poor hardware utilization.
            Default: ``256``.
        latent_features (int | float): size of the latent space for factorized
            linear modules. Can be specified as either an integer > 1 or as a
            float within ``[0, 0.5)``. In the latter case, the value is
            interpreted as a fraction of ``min(in_features, out_features)``
            for each :class:`torch.nn.Linear` module, and is converted to the
            equivalent integer value, with a minimum of 1. Default: ``128``.
"""def__init__(self,factorize_convs:bool=True,factorize_linears:bool=True,min_channels:int=256,latent_channels:Union[int,float]=0.25,min_features:int=256,latent_features:Union[int,float]=128,):self.factorize_convs=factorize_convsself.factorize_linears=factorize_linearsself.min_channels=min_channelsself.latent_channels=latent_channelsself.min_features=min_featuresself.latent_features=latent_featuresdef__repr__(self)->str:returnf'{self.__class__.__name__}(factorize_convs={self.factorize_convs},factorize_linears={self.factorize_linears},min_channels={self.min_channels},latent_channels={self.latent_channels},min_features={self.min_features},latent_features={self.latent_features})'@staticmethoddefrequired_on_load()->bool:returnTruedefmatch(self,event:Event,state:State)->bool:returnevent==Event.INITdefapply(self,event:Event,state:State,logger:Logger)->Optional[int]:assertstate.modelisnotNone,'Model must be part of state!'apply_factorization(model=state.model,factorize_convs=self.factorize_convs,factorize_linears=self.factorize_linears,min_channels=self.min_channels,latent_channels=self.latent_channels,min_features=self.min_features,latent_features=self.latent_features,optimizers=state.optimizers,)ifself.factorize_convs:num_factorized=module_surgery.count_module_instances(state.model,FactorizedConv2d)logger.log_hyperparameters({LOG_NUM_CONV2D_REPLACEMENTS_KEY:num_factorized,})ifself.factorize_linears:num_factorized=module_surgery.count_module_instances(state.model,FactorizedLinear)logger.log_hyperparameters({LOG_NUM_LINEAR_REPLACEMENTS_KEY:num_factorized,})
def _python_log_surgery_result(model: torch.nn.Module, new_class: Type[torch.nn.Module]):
    num_replaced_modules = module_surgery.count_module_instances(model, new_class)
    log.info(
        f'Applied factorization to model {model.__class__.__name__}. '
        f'Model now has {num_replaced_modules} {new_class.__name__} modules',
    )


def _factorize_conv2d_modules(
    model: torch.nn.Module,
    min_channels: int = 512,
    latent_channels: Union[int, float] = 0.25,
    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
):
    """Replaces :class:`torch.nn.Conv2d` modules in ``model`` with :class:`.FactorizedConv2d` modules.

    See :class:`.Factorize` for details.
    """

    def _maybe_replace_conv2d(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]:
        module = cast(torch.nn.Conv2d, module)
        wide_enough = min(module.out_channels, module.in_channels) >= min_channels
        if factorizing_could_speedup(module, latent_channels) and wide_enough:
            return FactorizedConv2d.from_conv2d(module, module_index, latent_channels=latent_channels)
        return None  # not enough rank reduction to be worth it

    ret = module_surgery.replace_module_classes(
        model,
        optimizers=optimizers,
        policies={torch.nn.Conv2d: _maybe_replace_conv2d},
    )
    _python_log_surgery_result(model, FactorizedConv2d)
    return ret


def _factorize_linear_modules(
    model: torch.nn.Module,
    min_features: int = 512,
    latent_features: Union[int, float] = 0.25,
    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
):
    """Replaces :class:`torch.nn.Linear` modules in ``model`` with :class:`.FactorizedLinear` modules.

    See :class:`.Factorize` for details.
    """

    def _maybe_replace_linear(module: torch.nn.Module, module_index: int) -> Optional[torch.nn.Module]:
        module = cast(torch.nn.Linear, module)
        wide_enough = min(module.in_features, module.out_features) >= min_features
        if factorizing_could_speedup(module, latent_features) and wide_enough:
            return FactorizedLinear.from_linear(module, module_index, latent_features=latent_features)
        return None  # not enough rank reduction to be worth it

    ret = module_surgery.replace_module_classes(
        model,
        optimizers=optimizers,
        policies={torch.nn.Linear: _maybe_replace_linear},
    )
    _python_log_surgery_result(model, FactorizedLinear)
    return ret
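
# Worked example of the break-even rule from ``Factorize``'s docstring: a
# factorized linear layer costs ``in * latent + latent * out`` multiply-adds
# versus ``in * out`` for the dense layer, so it only saves work when
# ``latent < (in * out) / (in + out)``, i.e. less than half the width when
# ``in == out``. This standalone helper is illustrative, not library code.
def _example_linear_factorization_saves_macs(in_features: int, out_features: int, latent: int) -> bool:
    dense_macs = in_features * out_features
    factorized_macs = in_features * latent + latent * out_features
    return factorized_macs < dense_macs

# e.g. for a 512 -> 512 linear layer: latent=128 saves work (131072 vs 262144
# multiply-adds), while latent=256 merely breaks even.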