# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Inference-related utility functions for model export and optimizations.
Used for exporting models into various formats such as ONNX, TorchScript, etc., and for applying optimizations such as fusion.
"""
from __future__ import annotations
import contextlib
import copy
import functools
import logging
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union
import torch
import torch.nn as nn
from composer.utils import dist
from composer.utils.checkpoint import download_checkpoint, safe_torch_load
from composer.utils.device import get_device
from composer.utils.iter_helpers import ensure_tuple
from composer.utils.misc import is_model_ddp, is_model_deepspeed, is_model_fsdp, model_eval_mode
from composer.utils.object_store import ObjectStore
from composer.utils.string_enum import StringEnum
if TYPE_CHECKING:
from composer.devices import Device
from composer.loggers import Logger
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Names exported as the public API of this module.
__all__ = ['export_for_inference', 'ExportFormat', 'export_with_logger', 'quantize_dynamic']

# A transform is any callable that takes a module and returns a module.
Transform = Callable[[nn.Module], nn.Module]

# ``quantize_dynamic`` is ``torch.quantization.quantize_dynamic`` with ``qconfig_spec`` pre-bound to
# ``{torch.nn.Linear}``. This is the most common way to use dynamic quantization.
#
# Example:
#
#     from composer.utils import quantize_dynamic
#
#     export_for_inference(
#         ...
#         transforms=[quantize_dynamic],
#         ...
#     )
#
# A user can always redefine it with extra options. This also serves as an example of what to pass
# to ``transforms``.
quantize_dynamic = functools.partial(torch.quantization.quantize_dynamic, qconfig_spec={torch.nn.Linear})
def _move_sample_input_to_device(
    sample_input: Optional[Union[torch.Tensor, dict, list, tuple]],
    device: Device,
) -> Optional[Union[torch.Tensor, dict, list, tuple]]:
    """Recursively move every tensor found in ``sample_input`` to ``device``.

    Where possible this avoids copying containers: dicts and lists are updated in place and the
    same object is returned. Tuples are immutable, so a new plain ``tuple`` is built for them.

    Args:
        sample_input: A tensor, or an arbitrarily nested dict/list/tuple structure. Values of any
            other type (``None``, numbers, strings, ...) are left untouched.
        device (Device): Object whose ``tensor_to_device`` method is used to move each tensor.

    Returns:
        The moved tensor, the (mutated) dict or list, a rebuilt tuple, or the unchanged input for
        any other type. ``None`` in gives ``None`` out.
    """
    if isinstance(sample_input, torch.Tensor):
        return device.tensor_to_device(sample_input)

    if isinstance(sample_input, dict):
        # Only values of existing keys are reassigned, so mutating while iterating is safe.
        for key, value in sample_input.items():
            sample_input[key] = _move_sample_input_to_device(value, device)
        return sample_input

    if isinstance(sample_input, list):
        for index, item in enumerate(sample_input):
            sample_input[index] = _move_sample_input_to_device(item, device)
        return sample_input

    if isinstance(sample_input, tuple):
        return tuple(_move_sample_input_to_device(item, device) for item in sample_input)

    # Non-tensor leaves (ints, strings, bools, None, ...) are passed through unchanged rather than
    # being replaced by None, so mixed sample inputs keep their non-tensor values.
    return sample_input
def export_for_inference(
    model: nn.Module,
    save_format: Union[str, ExportFormat],
    save_path: str,
    save_object_store: Optional[ObjectStore] = None,
    sample_input: Optional[Any] = None,
    dynamic_axes: Optional[Any] = None,
    surgery_algs: Optional[Union[Callable[[nn.Module], nn.Module], Sequence[Callable[[nn.Module], nn.Module]]]] = None,
    transforms: Optional[Sequence[Transform]] = None,
    onnx_opset_version: Optional[int] = None,
    load_path: Optional[str] = None,
    load_object_store: Optional[ObjectStore] = None,
    load_strict: bool = False,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
) -> None:
    """Export a model for inference.

    Only the process with global rank 0 performs the export; all other ranks return without
    writing anything.

    Args:
        model (nn.Module): An instance of nn.Module. Please note that model is not modified inplace.
            Instead, export-related transformations are applied to a copy of the model.
        save_format (Union[str, ExportFormat]): Format to export to. Either ``"torchscript"`` or ``"onnx"``.
        save_path: (str): The path for storing the exported model. It can be a path to a file on the local disk,
            a URL, or if ``save_object_store`` is set, the object name
            in a cloud bucket. For example, ``my_run/exported_model``.
        save_object_store (ObjectStore, optional): If the ``save_path`` is in an object name in a cloud bucket
            (i.e. AWS S3 or Google Cloud Storage), an instance of
            :class:`~.ObjectStore` which will be used
            to store the exported model. Set this to ``None`` if ``save_path`` is a local filepath.
            (default: ``None``)
        sample_input (Any, optional): Example model inputs used for tracing. This is needed for "onnx" export.
            The ``sample_input`` need not match the batch size you intend to use for inference. However, the model
            should accept the ``sample_input`` as is. A deep copy is used, so the caller's object is not
            modified. (default: ``None``)
        dynamic_axes (Any, optional): Dictionary specifying the axes of input/output tensors as dynamic. May be required
            for exporting models using older versions of PyTorch when types cannot be inferred.
        surgery_algs (Union[Callable, Sequence[Callable]], optional): Algorithms that should be applied to the model
            before loading a checkpoint. Each should be a callable that takes a model and modifies it in place;
            its return value is ignored. ``surgery_algs`` are applied before ``transforms``. (default: ``None``)
        transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
            be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
            ``transforms`` are applied after ``surgery_algs``. (default: ``None``)
        onnx_opset_version (int, optional): The opset version ONNX should use for exporting. Only used if save_format is ``"onnx"``.
            Defaults to Pytorch's default torch.onnx.export opset version, which changes by PyTorch version. (default: ``None``)
        load_path (str): The path to an existing checkpoint file.
            It can be a path to a file on the local disk, a URL, or if ``load_object_store`` is set, the object name
            for a checkpoint in a cloud bucket. For example, run_name/checkpoints/ep0-ba4-rank0. (default: ``None``)
        load_object_store (ObjectStore, optional): If the ``load_path`` is in an object name in a cloud bucket
            (i.e. AWS S3 or Google Cloud Storage), an instance of
            :class:`~.ObjectStore` which will be used to retrieve the checkpoint.
            Otherwise, if the checkpoint is a local filepath, set to ``None``. (default: ``None``)
        load_strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
            perfectly match the keys in the model instance. (default: ``False``)
        input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
            to ``None``, the keys from the `sample_input` will be used. Falls back to ``["input"]``.
        output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
            to ``None``, it defaults to ``["output"]``.

    Returns:
        None

    Raises:
        ValueError: If ``model`` is a DeepSpeed, DistributedDataParallel or FSDP-wrapped model, or if
            ``save_format`` is ``"onnx"`` and ``sample_input`` is not provided.
        RuntimeError: If ``save_format`` is ``"torchscript"``, scripting fails, and no ``sample_input``
            is available to fall back to tracing.
    """
    save_format = ExportFormat(save_format)

    if is_model_deepspeed(model):
        raise ValueError('Exporting for deepspeed models is currently not supported.')

    if is_model_ddp(model):
        raise ValueError(
            'Directly exporting a DistributedDataParallel model is not supported. Export the module instead.',
        )

    if is_model_fsdp(model):
        raise ValueError(
            'Directly exporting a FSDP wrapped module is not supported as the model is deepcopied to avoid '
            'side-effects, and FSDP does not support deepcopying. To export the model, load it without FSDP '
            'wrapping.',
        )

    # Only rank0 exports the model
    if dist.get_global_rank() != 0:
        return

    # Work on copies so that neither the caller's model nor its sample input is modified
    model = copy.deepcopy(model)
    sample_input = copy.deepcopy(sample_input)

    # Move model and sample input to CPU for export
    cpu = get_device('cpu')
    cpu.module_to_device(model)
    if sample_input is not None:
        sample_input = ensure_tuple(sample_input)
        sample_input = _move_sample_input_to_device(sample_input, cpu)

    # Apply surgery algorithms in the given order; they act on the model in place
    for alg in ensure_tuple(surgery_algs):
        alg(model)

    if load_path is not None:
        # download checkpoint and load weights only
        log.debug('Loading checkpoint at %s', load_path)
        with tempfile.TemporaryDirectory() as tempdir:
            composer_states_filepath, _, _ = download_checkpoint(
                path=load_path,
                node_checkpoint_folder=tempdir,
                object_store=load_object_store,
                progress_bar=True,
            )
            state_dict = safe_torch_load(composer_states_filepath)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict)
            if missing_keys:
                log.warning('Found these missing keys in the checkpoint: %s', ', '.join(missing_keys))
            if unexpected_keys:
                log.warning('Found these unexpected keys in the checkpoint: %s', ', '.join(unexpected_keys))

    with model_eval_mode(model):
        # Apply transformations (i.e., inference optimizations) in the given order
        for transform in ensure_tuple(transforms):
            model = transform(model)

        # When uploading to an object store, export to a scratch file first and upload it afterwards
        is_remote_store = save_object_store is not None
        tempdir_ctx = tempfile.TemporaryDirectory() if is_remote_store else contextlib.nullcontext(None)
        with tempdir_ctx as tempdir:
            local_save_path = os.path.join(str(tempdir), 'model.export') if is_remote_store else save_path

            if save_format == ExportFormat.TORCHSCRIPT:
                try:
                    export_model = torch.jit.script(model)
                except Exception as script_error:
                    if sample_input is None:
                        log.warning(
                            'Scripting with torch.jit.script failed and sample inputs are not provided for tracing '
                            'with torch.jit.trace',
                            exc_info=True,
                        )
                        # Chain the scripting failure so the root cause is visible to the caller
                        raise RuntimeError(
                            'Scripting and tracing failed! No model is getting exported.',
                        ) from script_error
                    log.warning('Scripting with torch.jit.script failed. Trying torch.jit.trace!')
                    export_model = torch.jit.trace(model, sample_input)

                torch.jit.save(export_model, local_save_path)

            elif save_format == ExportFormat.ONNX:
                if sample_input is None:
                    raise ValueError('sample_input argument is required for onnx export')

                # assert statement for pyright error: Cannot access member "keys" for type "Tensor"
                assert isinstance(sample_input, tuple)

                if input_names is None:
                    # Use the keys of any dict found among the top-level sample inputs as input names
                    inferred_names = []
                    for item in sample_input:
                        if isinstance(item, dict):
                            inferred_names.extend(item.keys())
                    # Default input name if no dict present
                    input_names = inferred_names or ['input']

                if output_names is None:
                    output_names = ['output']

                torch.onnx.export(
                    model,
                    sample_input,
                    local_save_path,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                )

            # upload if required.
            if save_object_store is not None:
                save_object_store.upload_object(save_path, local_save_path)
def export_with_logger(
    model: nn.Module,
    save_format: Union[str, ExportFormat],
    save_path: str,
    logger: Logger,
    save_object_store: Optional[ObjectStore] = None,
    sample_input: Optional[Any] = None,
    transforms: Optional[Sequence[Transform]] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
) -> None:
    """Helper method for exporting a model for inference.

    Exports the model to:
        1) save_object_store, if one is provided,
        2) logger.upload_file(save_path), if (1) does not apply and the logger has a destination that supports file uploading,
        3) locally, if (1) and (2) do not apply.

    Args:
        model (nn.Module): An instance of nn.Module. Please note that model is not modified inplace.
            Instead, export-related transformations are applied to a copy of the model.
        save_format (Union[str, ExportFormat]): Format to export to. Either ``"torchscript"`` or ``"onnx"``.
        save_path: (str): The path for storing the exported model. It can be a path to a file on the local disk,
            a URL, or if ``save_object_store`` is set, the object name
            in a cloud bucket. For example, ``my_run/exported_model``.
        logger (Logger): If this logger has a destination that supports file uploading, and save_object_store
            is not provided, this logger is used to export the model.
        save_object_store (ObjectStore, optional): If the ``save_path`` is in an object name in a cloud bucket
            (i.e. AWS S3 or Google Cloud Storage), an instance of
            :class:`~.ObjectStore` which will be used
            to store the exported model. Set this to ``None`` if the logger should be used to export the model or
            if ``save_path`` is a local filepath.
            (default: ``None``)
        sample_input (Any, optional): Example model inputs used for tracing. This is needed for "onnx" export.
            The ``sample_input`` need not match the batch size you intend to use for inference. However, the model
            should accept the ``sample_input`` as is. (default: ``None``)
        transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
            be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
            (default: ``None``)
        input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
            to ``None``, the keys from the `sample_input` will be used. Falls back to ``["input"]``.
        output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
            to ``None``, it defaults to ``["output"]``.

    Returns:
        None
    """
    upload_via_logger = save_object_store is None and logger.has_file_upload_destination()

    if not upload_via_logger:
        # Either an object store was given, or the export goes straight to the local ``save_path``
        export_for_inference(
            model=model,
            save_format=save_format,
            save_path=save_path,
            save_object_store=save_object_store,
            sample_input=sample_input,
            transforms=transforms,
            input_names=input_names,
            output_names=output_names,
        )
        return

    # Export to a scratch file and hand it to the logger while the temporary directory still exists
    with tempfile.TemporaryDirectory() as tmpdir:
        temp_local_save_path = os.path.join(str(tmpdir), 'model')
        export_for_inference(
            model=model,
            save_format=save_format,
            save_path=temp_local_save_path,
            sample_input=sample_input,
            transforms=transforms,
            input_names=input_names,
            output_names=output_names,
        )
        # export_for_inference writes nothing on non-zero global ranks, so there is only
        # something to upload when the scratch file was actually produced.
        if os.path.exists(temp_local_save_path):
            logger.upload_file(remote_file_name=save_path, file_path=temp_local_save_path)