Source code for composer.callbacks.oom_observer

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""
from __future__ import annotations

import dataclasses
import logging
import os
import pickle
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch.cuda
from packaging import version

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri

# Names exported by a wildcard import of this module.
__all__ = ['OOMObserver']

# Logger named after this module; used by the OOM observer below.
log = logging.getLogger(name=__name__)


@dataclass(frozen=True)
class SnapshotFileNameConfig:
    """Immutable bundle of output paths for one captured memory snapshot.

    Every field is a shared path prefix followed by an artifact-specific suffix.
    Build instances from that prefix with :meth:`from_file_name`.
    """

    snapshot_file: str  # ``<prefix>_snapshot.pickle``
    trace_plot_file: str  # ``<prefix>_trace_plot.html``
    segment_plot_file: str  # ``<prefix>_segment_plot.html``
    segment_flamegraph_file: str  # ``<prefix>_segment_flamegraph.svg``
    memory_flamegraph_file: str  # ``<prefix>_memory_flamegraph.svg``

    @classmethod
    def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig':
        """Derive every artifact path from a common prefix.

        Args:
            filename (str): Path prefix shared by all artifacts.

        Returns:
            SnapshotFileNameConfig: A config whose fields are ``filename`` with each
            artifact's suffix appended.
        """
        suffix_by_field = {
            'snapshot_file': '_snapshot.pickle',
            'trace_plot_file': '_trace_plot.html',
            'segment_plot_file': '_segment_plot.html',
            'segment_flamegraph_file': '_segment_flamegraph.svg',
            'memory_flamegraph_file': '_memory_flamegraph.svg',
        }
        return cls(**{field_name: filename + suffix for field_name, suffix in suffix_by_field.items()})

    def list_filenames(self) -> List[str]:
        """Return every artifact path, in field declaration order."""
        return list(dataclasses.astuple(self))


class OOMObserver(Callback):
    """Generate visualizations of the state of allocated memory during an OutOfMemory exception.

    This callback registers an observer with the allocator. The observer is called every time the allocator is
    about to raise an OutOfMemoryError, before any memory has been released while unwinding the exception.
    OOMObserver is attached to the Trainer at init stage.

    The visualizations include:

    - a snapshot of the memory state
    - a trace plot
    - a segment plot
    - a segment flamegraph
    - a memory flamegraph

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from composer.callbacks import OOMObserver
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[OOMObserver()],
            ... )

    .. note::
        OOMObserver is only supported for GPU devices, and it requires PyTorch 2.1.0 or newer. Otherwise the
        callback disables itself with a warning.

    Args:
        max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000.
        folder (str, optional): A format string describing the folder containing the memory visualization files.
            Defaults to ``'{{run_name}}/torch_traces'``.
        filename (str, optional): A format string describing the prefix used to name the memory visualization files.
            Defaults to ``'rank{{rank}}_oom'``.
        remote_file_name (str, optional): A format string describing the prefix for the memory visualization remote
            file name. Defaults to ``'{{run_name}}/oom_traces/rank{{rank}}_oom'``.

            Whenever a trace file is saved, it is also uploaded as a file according to this format string.
            The same format variables as for ``filename`` are available.

            .. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes on file uploading.

            Leading slashes (``'/'``) will be stripped.

            To disable uploading trace files, set this parameter to ``None``.
        overwrite (bool, optional): Whether to overwrite existing memory snapshots, both in the local folder and
            on upload. Defaults to False.

            If False, then the trace folder as determined by ``folder`` must be empty.
    """

    def __init__(
        self,
        max_entries: int = 100000,
        folder: str = '{run_name}/torch_traces',
        filename: str = 'rank{rank}_oom',
        remote_file_name: Optional[str] = '{run_name}/oom_traces/rank{rank}_oom',
        overwrite: bool = False,
    ) -> None:
        self.max_entries = max_entries
        self.folder = folder
        # Local folder after run-name/rank formatting; resolved in ``init``.
        self.folder_name: Optional[str] = None
        self.filename = filename
        self.remote_file_name = remote_file_name
        self.overwrite = overwrite

        # Only the path component of the remote URI is used as the upload prefix.
        # ``None`` disables uploading.
        if remote_file_name:
            _, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
        else:
            self.remote_path_in_bucket = None

        # OOMObserver is only supported in torch v2.1.0 or higher.
        # A ``.dev`` suffix is stripped so development builds compare by their base version.
        if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'):  # type: ignore
            self._enabled = True
        else:
            self._enabled = False
            warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')

        # Paths of the most recently written visualization files; set when an OOM is observed.
        self.filename_config: Optional[SnapshotFileNameConfig] = None

    def init(self, state: State, logger: Logger) -> None:
        """Prepare the output folder, start recording memory history, and register the OOM observer.

        Args:
            state (State): Trainer state; its model's device, ``run_name`` and ``timestamp`` are read.
            logger (Logger): Used to upload the visualization files when ``remote_file_name`` is set.
        """
        if not self._enabled:
            return

        # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
        model_device = next(state.model.parameters()).device

        if model_device.type not in ('cuda', 'meta'):
            warnings.warn(
                f'OOMObserver only works on CUDA devices, but the model is on {model_device.type}. Disabling OOMObserver.',
            )
            self._enabled = False
        else:
            self.folder_name = format_name_with_dist(self.folder, state.run_name)
            os.makedirs(self.folder_name, exist_ok=True)
            if not self.overwrite:
                ensure_folder_is_empty(self.folder_name)

        def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
            # The allocator calls this with four arguments. None of them are needed here,
            # but the signature must accept them.
            # Snapshot right after an OOM happened
            log.warning('Out Of Memory (OOM) observed')
            assert self.filename
            assert self.folder_name, 'folder_name must be set in init'
            filename = Path(self.folder_name) / Path(
                format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp),
            )

            try:
                self.filename_config = SnapshotFileNameConfig.from_file_name(str(filename))
                log.info('Dumping OOMObserver visualizations')
                if not self._write_visualizations(self.filename_config):
                    return
                log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

                if self.remote_path_in_bucket is not None:
                    self._upload_visualizations(self.filename_config, logger)
            except Exception as e:
                # Best-effort: any failure (including the re-raised upload FileExistsError) is logged
                # with its traceback instead of propagating out of the observer.
                log.exception(f'Failed to capture memory snapshot {e}')

        if self._enabled:
            torch.cuda.memory._record_memory_history(
                True,  # type: ignore
                trace_alloc_max_entries=self.max_entries,
                trace_alloc_record_context=True,
            )
            torch._C._cuda_attach_out_of_memory_observer(oom_observer)  # type: ignore
            log.info('OOMObserver is enabled and registered')

    def _write_visualizations(self, filename_config: SnapshotFileNameConfig) -> bool:
        """Write the current memory snapshot and its visualizations to the local paths in ``filename_config``.

        Returns:
            bool: ``False`` if the snapshot has no recorded allocations, in which case nothing is written.
            ``True`` after all files have been written.
        """
        snapshot = torch.cuda.memory._snapshot()

        # No data was recorded - avoids a `ValueError` in `trace_plot`
        if all(len(t) == 0 for t in snapshot['device_traces']):
            log.info('No allocation is recorded in memory snapshot')
            return False

        with open(filename_config.snapshot_file, 'wb') as fd:
            pickle.dump(snapshot, fd)

        # Each text artifact is produced by one torch.cuda._memory_viz renderer.
        renderers = (
            (filename_config.trace_plot_file, torch.cuda._memory_viz.trace_plot),  # type: ignore
            (filename_config.segment_plot_file, torch.cuda._memory_viz.segment_plot),  # type: ignore
            (filename_config.segment_flamegraph_file, torch.cuda._memory_viz.segments),  # type: ignore
            (filename_config.memory_flamegraph_file, torch.cuda._memory_viz.memory),  # type: ignore
        )
        for path, render in renderers:
            with open(path, 'w+') as fd:
                fd.write(render(snapshot))
        return True

    def _upload_visualizations(self, filename_config: SnapshotFileNameConfig, logger: Logger) -> None:
        """Upload every local visualization file under ``self.remote_path_in_bucket``.

        Only call this when ``self.remote_path_in_bucket`` is not ``None``.

        Raises:
            FileExistsError: Raised when ``logger.upload_file`` raises ``FileExistsError``. The original error
                is chained, and the message adds this callback's ``overwrite`` setting.
        """
        for local_file in filename_config.list_filenames():
            remote_file_name = os.path.join(self.remote_path_in_bucket, os.path.basename(local_file))
            remote_file_name = remote_file_name.lstrip('/')  # remove leading slashes
            log.info(f'Uploading memory visualization to remote: {remote_file_name} from {local_file}')
            try:
                logger.upload_file(remote_file_name=remote_file_name, file_path=local_file, overwrite=self.overwrite)
            except FileExistsError as e:
                # The upload honors this callback's ``overwrite`` flag, not the Trainer's ``save_overwrite``.
                raise FileExistsError(
                    f'Uploading memory visualizations failed with error: {e}. overwrite was set to {self.overwrite}. '
                    'To overwrite memory visualizations, set overwrite=True on OOMObserver.',
                ) from e