Source code for composer.loggers.wandb_logger

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log to `Weights and Biases <https://wandb.ai/>`_."""

from __future__ import annotations

import atexit
import copy
import os
import pathlib
import re
import sys
import tempfile
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Optional, Sequence, Union

import numpy as np
import torch

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist

if TYPE_CHECKING:
    from composer.core import State

__all__ = ['WandBLogger']


class WandBLogger(LoggerDestination):
    """Log to `Weights and Biases <https://wandb.ai/>`_.

    Args:
        project (str, optional): WandB project name.
        group (str, optional): WandB group name.
        name (str, optional): WandB run name.
            If not specified, the :attr:`.State.run_name` will be used.
        entity (str, optional): WandB entity name.
        tags (list[str], optional): WandB tags.
        log_artifacts (bool, optional): Whether to log
            `artifacts <https://docs.wandb.ai/ref/python/artifact>`_ (Default: ``False``).
        rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
            When logging `artifacts <https://docs.wandb.ai/ref/python/artifact>`_, it is
            highly recommended to log on all ranks. Artifacts from ranks ≥ 1 will not be
            stored, which may discard pertinent information (default: ``True``).
        init_kwargs (dict[str, Any], optional): Any additional kwargs to pass to ``wandb.init``
            (see the `WandB documentation <https://docs.wandb.ai/ref/python/init>`_).
    """

    def __init__(
        self,
        project: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[list[str]] = None,
        log_artifacts: bool = False,
        rank_zero_only: bool = True,
        init_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        try:
            import wandb
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='wandb',
                conda_package='wandb',
                conda_channel='conda-forge',
            ) from e
        del wandb  # unused

        if log_artifacts and rank_zero_only and dist.get_world_size() > 1:
            warnings.warn((
                'When logging artifacts, `rank_zero_only` should be set to False. '
                'Artifacts from other ranks will not be collected, leading to a loss of information required to '
                'restore from checkpoints.'
            ))
        self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

        if init_kwargs is None:
            init_kwargs = {}

        if project is not None:
            init_kwargs['project'] = project

        if group is not None:
            init_kwargs['group'] = group

        if name is not None:
            init_kwargs['name'] = name

        if entity is not None:
            init_kwargs['entity'] = entity

        if tags is not None:
            init_kwargs['tags'] = tags

        self._rank_zero_only = rank_zero_only
        self._log_artifacts = log_artifacts
        self._init_kwargs = init_kwargs
        self._is_in_atexit = False

        # Set these variables directly to allow fetching an Artifact **without** initializing a WandB run.
        # When used as a LoggerDestination, these values are overridden from global rank 0 to all ranks on Event.INIT.
        self.entity = entity
        self.project = project

        self.run_dir: Optional[str] = None
        self.run_url: Optional[str] = None
        self.table_dict = {}

    def _set_is_in_atexit(self):
        self._is_in_atexit = True

    def log_hyperparameters(self, hyperparameters: dict[str, Any]):
        if self._enabled:
            import wandb
            wandb.config.update(hyperparameters)

    def log_table(
        self,
        columns: list[str],
        rows: list[list[Any]],
        name: str = 'Table',
        step: Optional[int] = None,
    ) -> None:
        if self._enabled:
            import wandb
            table = wandb.Table(columns=columns, rows=rows)
            wandb.log({name: table}, step=step)

    def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
        if self._enabled:
            import wandb

            # wandb.log alters the metrics dictionary object, so we deepcopy to avoid
            # side effects.
            metrics_copy = copy.deepcopy(metrics)
            wandb.log(metrics_copy, step)

    def log_images(
        self,
        images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
        name: str = 'Images',
        channels_last: bool = False,
        step: Optional[int] = None,
        masks: Optional[dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]]] = None,
        mask_class_labels: Optional[dict[int, str]] = None,
        use_table: bool = False,
    ):
        if self._enabled:
            import wandb
            if not isinstance(images, Sequence) and images.ndim <= 3:
                images = [images]

            # _convert_to_wandb_image doesn't include wrapping with wandb.Image to future-proof
            # for when we support masks.
            images_generator = (_convert_to_wandb_image(image, channels_last) for image in images)

            if masks is not None:
                # Create a generator that yields masks in the format wandb wants.
                wandb_masks_generator = _create_wandb_masks_generator(
                    masks,
                    mask_class_labels,
                    channels_last=channels_last,
                )
                wandb_images = (
                    wandb.Image(im, masks=mask_dict)
                    for im, mask_dict in zip(images_generator, wandb_masks_generator)
                )
            else:
                wandb_images = (wandb.Image(image) for image in images_generator)

            if use_table:
                table = wandb.Table(columns=[name])
                for wandb_image in wandb_images:
                    table.add_data(wandb_image)
                wandb.log({name + ' Table': table}, step=step)
            else:
                wandb.log({name: list(wandb_images)}, step=step)

    def init(self, state: State, logger: Logger) -> None:
        import wandb
        del logger  # unused

        # Use the state run name if the name is not set.
        if 'name' not in self._init_kwargs or self._init_kwargs['name'] is None:
            self._init_kwargs['name'] = state.run_name

        # Adjust name and group based on `rank_zero_only`.
        if not self._rank_zero_only:
            name = self._init_kwargs['name']
            self._init_kwargs['name'] += f'-rank{dist.get_global_rank()}'
            self._init_kwargs['group'] = self._init_kwargs['group'] if 'group' in self._init_kwargs else name
        if self._enabled:
            wandb.init(**self._init_kwargs)
            assert wandb.run is not None, 'The wandb run is set after init'
            if hasattr(wandb.run, 'entity') and hasattr(wandb.run, 'project'):
                entity_and_project = [str(wandb.run.entity), str(wandb.run.project)]
            else:
                # Run does not have attributes if wandb is in disabled mode, so we must mock it.
                entity_and_project = ['disabled', 'disabled']
            self.run_dir = wandb.run.dir
            self.run_url = wandb.run.get_url()
            atexit.register(self._set_is_in_atexit)
        else:
            entity_and_project = [None, None]

        # Share the entity and project across all ranks, so they are available on ranks that did not initialize wandb.
        dist.broadcast_object_list(entity_and_project)
        self.entity, self.project = entity_and_project
        assert self.entity is not None, 'entity should be defined'
        assert self.project is not None, 'project should be defined'

    def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
        del overwrite  # unused

        if self._enabled and self._log_artifacts:
            import wandb

            # Some WandB-specific alias extraction.
            timestamp = state.timestamp
            aliases = ['latest', f'ep{int(timestamp.epoch)}-ba{int(timestamp.batch)}']

            # Replace all unsupported characters with periods.
            # Only alpha-numeric characters, periods, hyphens, and underscores are supported by wandb.
            new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.]', '.', remote_file_name)
            if new_remote_file_name != remote_file_name:
                warnings.warn((
                    'WandB permits only alpha-numeric characters, periods, hyphens, and underscores in file names. '
                    f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'."
                ))

            extension = new_remote_file_name.split('.')[-1]

            metadata = {f'timestamp/{k}': v for (k, v) in state.timestamp.state_dict().items()}
            # If evaluating, also log the evaluation timestamp.
            if state.dataloader is not state.train_dataloader:
                # TODO: If not actively training, then it is impossible to tell from the state whether
                # the trainer is evaluating or predicting. Assuming evaluation in this case.
                metadata.update({f'eval_timestamp/{k}': v for (k, v) in state.eval_timestamp.state_dict().items()})

            # Change the extension so the checkpoint is compatible with W&B's model registry.
            if extension == 'pt':
                extension = 'model'

            wandb_artifact = wandb.Artifact(
                name=new_remote_file_name,
                type=extension,
                metadata=metadata,
            )
            wandb_artifact.add_file(os.path.abspath(file_path))
            wandb.log_artifact(wandb_artifact, aliases=aliases)
    def can_upload_files(self) -> bool:
        """Whether the logger supports uploading files."""
        return True
    def download_file(
        self,
        remote_file_name: str,
        destination: str,
        overwrite: bool = False,
        progress_bar: bool = True,
    ):
        # Note: WandB doesn't support progress bars for downloading.
        del progress_bar
        import wandb
        import wandb.errors

        # Use wandb.Api() to support retrieving artifacts on ranks where
        # artifacts are not initialized.
        api = wandb.Api()
        if not self.entity or not self.project:
            raise RuntimeError('get_file_artifact can only be called after running init()')

        # Replace all unsupported characters with periods.
        # Only alpha-numeric characters, periods, hyphens, and underscores are supported by wandb.
        if ':' not in remote_file_name:
            remote_file_name += ':latest'

        new_remote_file_name = re.sub(r'[^a-zA-Z0-9-_\.:]', '.', remote_file_name)
        if new_remote_file_name != remote_file_name:
            warnings.warn((
                'WandB permits only alpha-numeric characters, periods, hyphens, and underscores in file names. '
                f"The file with name '{remote_file_name}' will be stored as '{new_remote_file_name}'."
            ))

        try:
            wandb_artifact = api.artifact('/'.join([self.entity, self.project, new_remote_file_name]))
        except wandb.errors.CommError as e:
            raise FileNotFoundError(f'WandB Artifact {new_remote_file_name} not found') from e

        with tempfile.TemporaryDirectory() as tmpdir:
            wandb_artifact_folder = os.path.join(tmpdir, 'wandb_artifact_folder/')
            wandb_artifact.download(root=wandb_artifact_folder)
            wandb_artifact_names = os.listdir(wandb_artifact_folder)
            # We only log one file per artifact.
            if len(wandb_artifact_names) > 1:
                raise RuntimeError(
                    'Found more than one file in WandB artifact. We assume the checkpoint is the only file in the WandB artifact.',
                )
            wandb_artifact_name = wandb_artifact_names[0]
            wandb_artifact_path = os.path.join(wandb_artifact_folder, wandb_artifact_name)
            if overwrite:
                os.replace(wandb_artifact_path, destination)
            else:
                os.rename(wandb_artifact_path, destination)

    def post_close(self) -> None:
        import wandb

        # Clean up on post_close so all artifacts are uploaded.
        if not self._enabled or wandb.run is None or self._is_in_atexit:
            # Don't call wandb.finish if there is no run, or if
            # the script is in an atexit, since wandb also hooks into atexit
            # and it will error if wandb.finish is called from the Composer atexit hook
            # after it is called from the wandb atexit hook.
            return

        exc_type, exc_info, tb = sys.exc_info()
        if (exc_type, exc_info, tb) == (None, None, None):
            wandb.finish(0)
        else:
            # Record that there was an error.
            wandb.finish(1)
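

# ----------------------------------------------------------------------------------------------
# Illustrative usage sketch (added to this rendered page; not part of the original module).
# It shows one common way a ``WandBLogger`` is constructed and handed to the Composer ``Trainer``
# via its ``loggers`` argument. The project/entity names and the ``model`` and
# ``train_dataloader`` parameters below are hypothetical placeholders.
# ----------------------------------------------------------------------------------------------
def _example_wandb_logger_usage(model, train_dataloader):
    from composer import Trainer

    wandb_logger = WandBLogger(
        project='my-project',  # hypothetical project name
        entity='my-team',  # hypothetical entity name
        tags=['baseline'],
        log_artifacts=True,
        # When logging artifacts on a multi-GPU run, rank_zero_only=False is recommended
        # (see the warning emitted in __init__ above).
        rank_zero_only=False,
    )

    trainer = Trainer(
        model=model,  # any ComposerModel
        train_dataloader=train_dataloader,
        max_duration='1ep',
        loggers=[wandb_logger],
    )
    trainer.fit()
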
def _convert_to_wandb_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
    if isinstance(image, torch.Tensor):
        if image.dtype == torch.float16 or image.dtype == torch.bfloat16:
            image = image.data.cpu().to(torch.float32).numpy()
        else:
            image = image.data.cpu().numpy()

    # Error out for empty arrays or weird arrays of dimension 0.
    if np.any(np.equal(image.shape, 0)):
        raise ValueError(f'Got an image (shape {image.shape}) with at least one dimension being 0!')

    # Squeeze any singleton dimensions and then add them back in if the image has
    # fewer than 3 dimensions.
    image = image.squeeze()

    # Add in length-one dimensions to get back up to 3, putting channels last.
    if image.ndim == 1:
        image = np.expand_dims(image, (1, 2))
        channels_last = True
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
        channels_last = True

    if image.ndim != 3:
        raise ValueError(
            textwrap.dedent(
                f'''Input image must be 3 dimensions, but instead got {image.ndim} dims at shape: {image.shape}
                Your input image was interpreted as a batch of {image.ndim}-dimensional images
                because you either specified a {image.ndim + 1}D image or a list of {image.ndim}D images.
                Please specify either a 4D image or a list of 3D images.''',
            ),
        )
    assert isinstance(image, np.ndarray)
    if not channels_last:
        image = image.transpose(1, 2, 0)
    return image


def _convert_to_wandb_mask(mask: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
    mask = _convert_to_wandb_image(mask, channels_last)
    mask = mask.squeeze()
    if mask.ndim != 2:
        raise ValueError(f'Mask must be a 2D array, but instead got array of shape: {mask.shape}')
    return mask


def _preprocess_mask_data(
    masks: dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]],
    channels_last: bool,
) -> dict[str, np.ndarray]:
    preprocessed_masks = {}
    for mask_name, mask_data in masks.items():
        if not isinstance(mask_data, Sequence):
            mask_data = mask_data.squeeze()
            if mask_data.ndim == 2:
                mask_data = [mask_data]
        preprocessed_masks[mask_name] = np.stack([_convert_to_wandb_mask(mask, channels_last) for mask in mask_data])
    return preprocessed_masks


def _create_wandb_masks_generator(
    masks: dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]],
    mask_class_labels: Optional[dict[int, str]],
    channels_last: bool,
):
    preprocessed_masks: dict[str, np.ndarray] = _preprocess_mask_data(masks, channels_last)
    for all_masks_for_single_example in zip(*list(preprocessed_masks.values())):
        mask_dict = {name: {'mask_data': mask} for name, mask in zip(masks.keys(), all_masks_for_single_example)}
        if mask_class_labels is not None:
            for k in mask_dict.keys():
                mask_dict[k].update({'class_labels': mask_class_labels})
        yield mask_dict
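

# ----------------------------------------------------------------------------------------------
# Illustrative sketch (added to this rendered page; not part of the original module) of how the
# mask helpers above are exercised through ``WandBLogger.log_images``. The shapes, class ids,
# and label names are made up for illustration, and the logger is assumed to have already been
# initialized (e.g. by the Trainer calling init()).
# ----------------------------------------------------------------------------------------------
def _example_log_images_with_masks(wandb_logger: WandBLogger):
    # A batch of two CHW images; channels_last=False matches that layout.
    images = np.random.rand(2, 3, 32, 32)

    # One 2D integer mask per image for each mask type; the values index into mask_class_labels.
    masks = {
        'predictions': np.random.randint(0, 2, size=(2, 32, 32)),
        'ground_truth': np.random.randint(0, 2, size=(2, 32, 32)),
    }

    wandb_logger.log_images(
        images=images,
        name='segmentation',
        channels_last=False,
        masks=masks,
        mask_class_labels={0: 'background', 1: 'foreground'},
    )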