# Source code for composer.utils.misc
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Miscellaneous Helpers."""
import socket
from contextlib import contextmanager
from typing import Type
import torch
from torch.nn.parallel import DistributedDataParallel
# Public API of this module. Note that ``is_model_ddp`` is deliberately absent:
# it is internal and not re-exported via ``from composer.utils.misc import *``.
__all__ = [
'is_model_deepspeed', 'is_model_fsdp', 'is_notebook', 'warning_on_one_line', 'get_free_tcp_port', 'model_eval_mode'
]
def is_model_deepspeed(model: torch.nn.Module) -> bool:
    """Whether ``model`` is an instance of a :class:`~deepspeed.DeepSpeedEngine`.

    Returns ``False`` when the ``deepspeed`` package is not installed, since the
    model cannot possibly be a DeepSpeed engine in that case.
    """
    try:
        from deepspeed import DeepSpeedEngine
    except ImportError:
        # deepspeed is an optional dependency; without it, nothing to check.
        return False
    return isinstance(model, DeepSpeedEngine)
def is_model_ddp(model: torch.nn.Module) -> bool:
    """Whether ``model`` is an instance of a :class:`.DistributedDataParallel`."""
    ddp_wrapped = isinstance(model, DistributedDataParallel)
    return ddp_wrapped
def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Whether ``model`` is (or contains) a :class:`.FullyShardedDataParallel` instance.

    Returns ``False`` when the running torch build does not provide FSDP.

    Args:
        model: The model to inspect.

    Returns:
        bool: ``True`` if ``model`` itself, or any module nested within it at any
        depth, is an FSDP wrapper; ``False`` otherwise.
    """
    try:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    except ImportError:
        # FSDP is unavailable in this torch build, so the model cannot use it.
        return False
    # ``model.modules()`` yields ``model`` itself first and then every descendant,
    # so this covers a model that is directly FSDP-wrapped as well as FSDP units
    # nested at any depth. (The original only looked at direct children via
    # ``named_children()`` and never exited the loop early.)
    return any(isinstance(module, FSDP) for module in model.modules())
def is_notebook():
    """Whether Composer is running in an IPython/Jupyter Notebook."""
    # IPython injects ``__IPYTHON__`` into builtins; in a plain interpreter the
    # bare name lookup raises NameError.
    try:
        __IPYTHON__  # type: ignore
    except NameError:
        return False
    return True
def warning_on_one_line(message: str, category: Type[Warning], filename: str, lineno: int, file=None, line=None):
    """Force Python warnings to consolidate into one line.

    Intended to be assigned to :data:`warnings.formatwarning`; the signature
    therefore matches the one Python's warnings machinery calls with (``file``
    and ``line`` are accepted but unused).

    Args:
        message: The warning message text.
        category: The warning class (its ``__name__`` is shown).
        filename: Path of the file that issued the warning.
        lineno: Line number the warning was issued from.

    Returns:
        str: A single-line formatted warning, newline-terminated.
    """
    # From https://stackoverflow.com/questions/26430861/make-pythons-warnings-warn-not-mention-itself
    # Bug fix: the format string had dropped ``{filename}`` (it read a literal
    # "(unknown)"), leaving the ``filename`` parameter unused.
    return f'{category.__name__}: {message} (source: {filename}:{lineno})\n'
def get_free_tcp_port() -> int:
    """Get a free TCP port to use as MASTER_PORT.

    Binds to port ``0`` so the OS assigns a currently-free ephemeral port, then
    closes the socket and returns that port number. Note the inherent race:
    another process may grab the port between this call and its later use.

    Returns:
        int: A TCP port number that was free at the time of the call.
    """
    # from https://www.programcreek.com/python/?CodeExample=get+free+port
    # ``with`` guarantees the socket is closed even if bind()/getsockname()
    # raises (the original leaked the descriptor on error).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp:
        tcp.bind(('', 0))
        _, port = tcp.getsockname()
    return port
@contextmanager
def model_eval_mode(model: torch.nn.Module):
    """Temporarily put ``model`` in eval mode, restoring its prior mode on exit.

    Args:
        model: The model whose training mode is toggled for the duration of the
            ``with`` block.
    """
    previous_mode = model.training
    try:
        model.eval()
        yield
    finally:
        # Restore whatever mode the model was in, even if the body raised.
        model.train(mode=previous_mode)