Source code for composer.utils.misc
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Miscellaneous Helpers."""
import socket
from contextlib import contextmanager
from typing import Type
import torch
from packaging import version
from torch.nn.parallel import DistributedDataParallel
# Public names of this module: these are what ``from <module> import *`` exports.
# NOTE(review): ``is_model_ddp`` and ``using_torch_2`` are defined in this file but are not listed here --
# confirm whether that omission is intentional.
__all__ = [
'is_model_deepspeed', 'is_model_fsdp', 'is_notebook', 'warning_on_one_line', 'get_free_tcp_port', 'model_eval_mode'
]
def is_model_deepspeed(model: torch.nn.Module) -> bool:
    """Report whether ``model`` is a :class:`~deepspeed.DeepSpeedEngine` instance.

    ``deepspeed`` is imported lazily inside the function, so it is only needed when this check runs.

    Args:
        model (torch.nn.Module): The module to inspect.

    Returns:
        bool: ``False`` when ``deepspeed`` cannot be imported; otherwise the result of the
        ``isinstance`` check against ``deepspeed.DeepSpeedEngine``.
    """
    try:
        import deepspeed
    except ImportError:
        # If the package cannot be imported, the model cannot be one of its engines.
        return False

    engine_cls = deepspeed.DeepSpeedEngine
    return isinstance(model, engine_cls)
def is_model_ddp(model: torch.nn.Module) -> bool:
    """Report whether ``model`` is a :class:`.DistributedDataParallel` instance.

    Args:
        model (torch.nn.Module): The module to inspect.

    Returns:
        bool: ``True`` if ``model`` itself is an instance of ``DistributedDataParallel``, else ``False``.
    """
    wrapped_in_ddp = isinstance(model, DistributedDataParallel)
    return wrapped_in_ddp
def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Report whether ``model`` is, or directly contains, a :class:`.FullyShardedDataParallel` module.

    The check covers ``model`` itself and the modules yielded by ``model.named_children()``.

    Args:
        model (torch.nn.Module): The module to inspect.

    Returns:
        bool: ``True`` if ``model`` or one of its named children is an FSDP instance. ``False`` otherwise,
        including when the FSDP class cannot be imported.
    """
    try:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        if isinstance(model, FSDP):
            return True

        # The model itself is not FSDP; see whether any child module has been wrapped instead.
        return any(isinstance(child, FSDP) for _, child in model.named_children())
    except ImportError:
        return False
def is_notebook():
    """Report whether Composer is running in an IPython/Jupyter notebook.

    Detection relies on whether the name ``__IPYTHON__`` resolves in the current interpreter.

    Returns:
        bool: ``True`` if ``__IPYTHON__`` is defined, ``False`` if looking it up raises ``NameError``.
    """
    try:
        # Merely evaluating the name is the probe; its value is not used.
        __IPYTHON__  # type: ignore
    except NameError:
        return False
    else:
        return True
def warning_on_one_line(message: str, category: Type[Warning], filename: str, lineno: int, file=None, line=None):
    """Format a Python warning as a single line of text.

    Args:
        message (str): The warning message.
        category (Type[Warning]): The warning class; only its ``__name__`` appears in the output.
        filename (str): The file the warning is attributed to.
        lineno (int): The line number the warning is attributed to.
        file: Accepted but unused.
        line: Accepted but unused.

    Returns:
        str: ``'<category name>: <message> (source: <filename>:<lineno>)'`` followed by a newline.
    """
    # From https://stackoverflow.com/questions/26430861/make-pythons-warnings-warn-not-mention-itself
    # NOTE(review): the signature presumably mirrors ``warnings.formatwarning`` so this function can replace it --
    # confirm against the call site.
    one_liner = f'{category.__name__}: {message} (source: {filename}:{lineno})\n'
    return one_liner
def get_free_tcp_port() -> int:
    """Get free socket port to use as MASTER_PORT.

    A temporary IPv4 TCP socket is bound to port ``0`` so that the operating system assigns an unused port.
    The assigned port number is read back and the socket is closed before returning.

    Because the socket is closed before the port is returned, nothing reserves the port afterwards; another
    process could bind it before the caller does.

    Returns:
        int: The port number the operating system assigned to the temporary socket.

    Raises:
        OSError: If the socket cannot be created or bound.
    """
    # from https://www.programcreek.com/python/?CodeExample=get+free+port
    # The ``with`` block guarantees the socket is closed even if ``bind`` or ``getsockname`` raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(('', 0))
        _, port = tcp.getsockname()
    return port
@contextmanager
def model_eval_mode(model: torch.nn.Module):
    """Put ``model`` in eval mode for the duration of the context, then restore its prior mode.

    The value of ``model.training`` is recorded on entry. On exit, whether or not an exception was raised,
    ``model.train`` is called with that recorded value.

    Args:
        model (torch.nn.Module): The module to switch to eval mode temporarily.

    Yields:
        None
    """
    # Remember the mode on entry so it can be restored on exit.
    was_training = model.training
    try:
        model.eval()
        yield
    finally:
        # Runs on both normal exit and exceptions.
        model.train(mode=was_training)
def using_torch_2() -> bool:
    """Check whether the installed PyTorch version is at least 2.0.0.

    Both ``torch.__version__`` and the string ``'2.0.0'`` are parsed with ``packaging.version.parse`` and the
    resulting version objects are compared.

    Returns:
        bool: ``True`` if the current version is greater than or equal to 2.0.0, else ``False``.
    """
    installed = version.parse(torch.__version__)
    minimum = version.parse('2.0.0')
    return installed >= minimum