Source code for composer.loggers.wandb_logger

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log to `Weights and Biases <>`_."""

from __future__ import annotations

import atexit
import copy
import os
import pathlib
import re
import sys
import tempfile
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist

    from composer.core import State

__all__ = ['WandBLogger']

[docs]class WandBLogger(LoggerDestination): """Log to `Weights and Biases <>`_. Args: project (str, optional): WandB project name. group (str, optional): WandB group name. name (str, optional): WandB run name. If not specified, the :attr:`.State.run_name` will be used. entity (str, optional): WandB entity name. tags (List[str], optional): WandB tags. log_artifacts (bool, optional): Whether to log `artifacts <>`_ (Default: ``False``). rank_zero_only (bool, optional): Whether to log only on the rank-zero process. When logging `artifacts <>`_, it is highly recommended to log on all ranks. Artifacts from ranks โ‰ฅ1 will not be stored, which may discard pertinent information. For example, when using Deepspeed ZeRO, it would be impossible to restore from checkpoints without artifacts from all ranks (default: ``False``). init_kwargs (Dict[str, Any], optional): Any additional init kwargs ``wandb.init`` (see `WandB documentation <>`_). """ def __init__( self, project: Optional[str] = None, group: Optional[str] = None, name: Optional[str] = None, entity: Optional[str] = None, tags: Optional[List[str]] = None, log_artifacts: bool = False, rank_zero_only: bool = True, init_kwargs: Optional[Dict[str, Any]] = None, ) -> None: try: import wandb except ImportError as e: raise MissingConditionalImportError(extra_deps_group='wandb', conda_package='wandb', conda_channel='conda-forge') from e del wandb # unused if log_artifacts and rank_zero_only and dist.get_world_size() > 1: warnings.warn( ('When logging artifacts, `rank_zero_only` should be set to False. ' 'Artifacts from other ranks will not be collected, leading to a loss of information required to ' 'restore from checkpoints.')) self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0 if init_kwargs is None: init_kwargs = {} if project is not None: init_kwargs['project'] = project if group is not None: init_kwargs['group'] = group if name is not None: init_kwargs['name'] = name if entity is not None: init_kwargs['entity'] = entity if tags is not None: init_kwargs['tags'] = tags self._rank_zero_only = rank_zero_only self._log_artifacts = log_artifacts self._init_kwargs = init_kwargs self._is_in_atexit = False # Set these variable directly to allow fetching an Artifact **without** initializing a WandB run # When used as a LoggerDestination, these values are overriden from global rank 0 to all ranks on Event.INIT self.entity = entity self.project = project self.run_dir: Optional[str] = None def _set_is_in_atexit(self): self._is_in_atexit = True def log_hyperparameters(self, hyperparameters: Dict[str, Any]): if self._enabled: import wandb wandb.config.update(hyperparameters) def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: if self._enabled: import wandb # wandb.log alters the metrics dictionary object, so we deepcopy to avoid # side effects. metrics_copy = copy.deepcopy(metrics) wandb.log(metrics_copy, step) def log_images( self, images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]], name: str = 'Images', channels_last: bool = False, step: Optional[int] = None, masks: Optional[Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]]] = None, mask_class_labels: Optional[Dict[int, str]] = None, use_table: bool = False, ): if self._enabled: import wandb if not isinstance(images, Sequence) and images.ndim <= 3: images = [images] # _convert_to_wandb_image doesn't include wrapping with wandb.Image to future # proof for when we support masks. images_generator = (_convert_to_wandb_image(image, channels_last) for image in images) if masks is not None: # Create a generator that yields masks in the format wandb wants. wandb_masks_generator = _create_wandb_masks_generator(masks, mask_class_labels, channels_last=channels_last) wandb_images = ( wandb.Image(im, masks=mask_dict) for im, mask_dict in zip(images_generator, wandb_masks_generator)) else: wandb_images = (wandb.Image(image) for image in images_generator) if use_table: table = wandb.Table(columns=[name]) for wandb_image in wandb_images: table.add_data(wandb_image) wandb.log({name + ' Table': table}, step=step) else: wandb.log({name: list(wandb_images)}, step=step) def state_dict(self) -> Dict[str, Any]: import wandb # Storing these fields in the state dict to support run resuming in the future. if self._enabled: if is None: raise ValueError('wandb module must be initialized before serialization.') # If WandB is disabled, most things are RunDisabled objects, which are not # pickleable due to overriding __getstate__ but not __setstate__ if return {} else: return { 'name':, 'project':, 'entity':, 'id':, 'group': } else: return {} def init(self, state: State, logger: Logger) -> None: import wandb del logger # unused # Use the state run name if the name is not set. if 'name' not in self._init_kwargs or self._init_kwargs['name'] is None: self._init_kwargs['name'] = state.run_name # Adjust name and group based on `rank_zero_only`. if not self._rank_zero_only: name = self._init_kwargs['name'] self._init_kwargs['name'] += f'-rank{dist.get_global_rank()}' self._init_kwargs['group'] = self._init_kwargs['group'] if 'group' in self._init_kwargs else name if self._enabled: wandb.init(**self._init_kwargs) assert is not None, 'The wandb run is set after init' entity_and_project = [str(, str(] self.run_dir = atexit.register(self._set_is_in_atexit) else: entity_and_project = [None, None] # Share the entity and project across all ranks, so they are available on ranks that did not initialize wandb dist.broadcast_object_list(entity_and_project) self.entity, self.project = entity_and_project assert self.entity is not None, 'entity should be defined' assert self.project is not None, 'project should be defined' def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool): del overwrite # unused if self._enabled and self._log_artifacts: import wandb # Some WandB-specific alias extraction timestamp = state.timestamp aliases = ['latest', f'ep{int(timestamp.epoch)}-ba{int(timestamp.batch)}'] # replace all unsupported characters with periods # Only alpha-numeric, periods, hyphens, and underscores are supported by wandb. new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.]', '.', remote_file_name) if new_remote_file_name != remote_file_name: warnings.warn(('WandB permits only alpha-numeric, periods, hyphens, and underscores in file names. ' f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'.")) extension = new_remote_file_name.split('.')[-1] metadata = {f'timestamp/{k}': v for (k, v) in state.timestamp.state_dict().items()} # if evaluating, also log the evaluation timestamp if state.dataloader is not state.train_dataloader: # TODO If not actively training, then it is impossible to tell from the state whether # the trainer is evaluating or predicting. Assuming evaluation in this case. metadata.update({f'eval_timestamp/{k}': v for (k, v) in state.eval_timestamp.state_dict().items()}) # Change the extension so the checkpoint is compatible with W&B's model registry if extension == 'pt': extension = 'model' wandb_artifact = wandb.Artifact( name=new_remote_file_name, type=extension, metadata=metadata, ) wandb_artifact.add_file(os.path.abspath(file_path)) wandb.log_artifact(wandb_artifact, aliases=aliases)
[docs] def can_upload_files(self) -> bool: """Whether the logger supports uploading files.""" return True
def download_file( self, remote_file_name: str, destination: str, overwrite: bool = False, progress_bar: bool = True, ): # Note: WandB doesn't support progress bars for downloading del progress_bar import wandb import wandb.errors # using the wandb.Api() to support retrieving artifacts on ranks where # artifacts are not initialized api = wandb.Api() if not self.entity or not self.project: raise RuntimeError('get_file_artifact can only be called after running init()') # replace all unsupported characters with periods # Only alpha-numeric, periods, hyphens, and underscores are supported by wandb. if ':' not in remote_file_name: remote_file_name += ':latest' new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.:]', '.', remote_file_name) if new_remote_file_name != remote_file_name: warnings.warn(('WandB permits only alpha-numeric, periods, hyphens, and underscores in file names. ' f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'.")) try: wandb_artifact = api.artifact('/'.join([self.entity, self.project, new_remote_file_name])) except wandb.errors.CommError as e: if 'does not contain artifact' in str(e): raise FileNotFoundError(f'WandB Artifact {new_remote_file_name} not found') from e raise e with tempfile.TemporaryDirectory() as tmpdir: wandb_artifact_folder = os.path.join(tmpdir, 'wandb_artifact_folder') wandb_artifact_names = os.listdir(wandb_artifact_folder) # We only log one file per artifact if len(wandb_artifact_names) > 1: raise RuntimeError( 'Found more than one file in WandB artifact. We assume the checkpoint is the only file in the WandB artifact.' ) wandb_artifact_name = wandb_artifact_names[0] wandb_artifact_path = os.path.join(wandb_artifact_folder, wandb_artifact_name) if overwrite: os.replace(wandb_artifact_path, destination) else: os.rename(wandb_artifact_path, destination) def post_close(self) -> None: import wandb # Cleaning up on post_close so all artifacts are uploaded if not self._enabled or is None or self._is_in_atexit: # Don't call wandb.finish if there is no run, or # the script is in an atexit, since wandb also hooks into atexit # and it will error if wandb.finish is called from the Composer atexit hook # after it is called from the wandb atexit hook return exc_tpe, exc_info, tb = sys.exc_info() if (exc_tpe, exc_info, tb) == (None, None, None): wandb.finish(0) else: # record there was an error wandb.finish(1)
def _convert_to_wandb_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray: if isinstance(image, torch.Tensor): image = # Error out for empty arrays or weird arrays of dimension 0. if np.any(np.equal(image.shape, 0)): raise ValueError(f'Got an image (shape {image.shape}) with at least one dimension being 0! ') # Squeeze any singleton dimensions and then add them back in if image dimension # less than 3. image = image.squeeze() # Add in length-one dimensions to get back up to 3 # putting channels last. if image.ndim == 1: image = np.expand_dims(image, (1, 2)) channels_last = True if image.ndim == 2: image = np.expand_dims(image, 2) channels_last = True if image.ndim != 3: raise ValueError( textwrap.dedent(f'''Input image must be 3 dimensions, but instead got {image.ndim} dims at shape: {image.shape} Your input image was interpreted as a batch of {image.ndim} -dimensional images because you either specified a {image.ndim + 1}D image or a list of {image.ndim}D images. Please specify either a 4D image of a list of 3D images''')) assert isinstance(image, np.ndarray) if not channels_last: image = image.transpose(1, 2, 0) return image def _convert_to_wandb_mask(mask: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray: mask = _convert_to_wandb_image(mask, channels_last) mask = mask.squeeze() if mask.ndim != 2: raise ValueError(f'Mask must be a 2D array, but instead got array of shape: {mask.shape}') return mask def _preprocess_mask_data(masks: Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]], channels_last: bool) -> Dict[str, np.ndarray]: preprocesssed_masks = {} for mask_name, mask_data in masks.items(): if not isinstance(mask_data, Sequence): mask_data = mask_data.squeeze() if mask_data.ndim == 2: mask_data = [mask_data] preprocesssed_masks[mask_name] = np.stack([_convert_to_wandb_mask(mask, channels_last) for mask in mask_data]) return preprocesssed_masks def _create_wandb_masks_generator(masks: Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]], mask_class_labels: Optional[Dict[int, str]], channels_last: bool): preprocessed_masks: Dict[str, np.ndarray] = _preprocess_mask_data(masks, channels_last) for all_masks_for_single_example in zip(*list(preprocessed_masks.values())): mask_dict = {name: {'mask_data': mask} for name, mask in zip(masks.keys(), all_masks_for_single_example)} if mask_class_labels is not None: for k in mask_dict.keys(): mask_dict[k].update({'class_labels': mask_class_labels}) yield mask_dict