Source code for composer.loggers.mlflow_logger

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log to `MLflow <https://www.mlflow.org/docs/latest/index.html>."""

from __future__ import annotations

import fnmatch
import logging
import multiprocessing
import os
import pathlib
import posixpath
import signal
import sys
import textwrap
import time
import warnings
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union

import numpy as np
import torch

from composer.core.state import State
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist

if TYPE_CHECKING:
    from mlflow import ModelVersion  # pyright: ignore[reportGeneralTypeIssues]

log = logging.getLogger(__name__)

__all__ = ['MLFlowLogger']

DEFAULT_MLFLOW_EXPERIMENT_NAME = 'my-mlflow-experiment'


class MlflowMonitorProcess(multiprocessing.Process):
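    """Watch the main training process and report its status to the MLflow tracking server.

    Runs as a separate process that polls the main process every 10 seconds and marks the
    MLflow run as FAILED if the main process dies or crashes, or as KILLED if the monitor
    receives SIGTERM while the run is still RUNNING.
    """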

    def __init__(self, main_pid, mlflow_run_id, mlflow_tracking_uri):
        super().__init__()
        self.main_pid = main_pid
        self.mlflow_run_id = mlflow_run_id
        self.mlflow_tracking_uri = mlflow_tracking_uri
        self.exit_event = multiprocessing.Event()
        self.crash_event = multiprocessing.Event()

    def handle_sigterm(self, signum, frame):
        from mlflow import MlflowClient
        client = MlflowClient(self.mlflow_tracking_uri)
        if client.get_run(self.mlflow_run_id).info.status == 'RUNNING':
            # Set the run status as KILLED if SIGTERM is received while the MLflow run is still
            # in status RUNNING.
            client.set_terminated(self.mlflow_run_id, status='KILLED')

    def run(self):
        from mlflow import MlflowClient

        os.setsid()
        # Register the signal handler in the child process
        signal.signal(signal.SIGTERM, self.handle_sigterm)

        while not self.exit_event.wait(10):
            try:
                # Signal 0 does not kill the process but performs error checking
                os.kill(self.main_pid, 0)
            except OSError:
                client = MlflowClient(self.mlflow_tracking_uri)
                client.set_terminated(self.mlflow_run_id, status='FAILED')
                break

        if self.crash_event.is_set():
            client = MlflowClient(self.mlflow_tracking_uri)
            client.set_terminated(self.mlflow_run_id, status='FAILED')

    def stop(self):
        self.exit_event.set()

    def crash(self):
        self.crash_event.set()
        self.exit_event.set()
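
# Liveness-check sketch (illustrative, not part of the library): MlflowMonitorProcess
# detects that the training process has died by sending signal 0 with os.kill, which
# performs existence/permission checks without actually delivering a signal. A minimal
# standalone version of that probe, assuming a `pid` of interest:
#
#     import os
#
#     def process_is_alive(pid: int) -> bool:
#         try:
#             os.kill(pid, 0)  # signal 0: error-check only, nothing is sent
#         except OSError:
#             return False
#         return True
#
# The monitor polls this every 10 seconds (via `exit_event.wait(10)`) and marks the
# MLflow run FAILED if the parent disappears without `exit_event` having been set.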


class MLFlowLogger(LoggerDestination):
    """Log to `MLflow <https://www.mlflow.org/docs/latest/index.html>`_.

    See the usage sketch after this class definition for a minimal example.

    Args:
        experiment_name (str, optional): MLflow experiment name. If not set, the MLflow
            environment variable or a default value is used.
        run_name (str, optional): MLflow run name. If not set, it defaults to the Trainer
            run name.
        tags (dict, optional): MLflow tags to log with the run.
        tracking_uri (str | pathlib.Path, optional): MLflow tracking URI, the URI of the
            remote or local endpoint where logs are stored. If ``None``, the MLflow
            default is used.
        rank_zero_only (bool, optional): Whether to log only on the rank-zero process
            (default: ``True``).
        flush_interval (int): The amount of time, in seconds, that MLflow must wait between
            logging batches of metrics. Any metrics that are recorded by Composer during
            this interval are enqueued, and the queue is flushed when the interval elapses
            (default: ``10``).
        model_registry_prefix (str, optional): The prefix to use when registering models.
            If registering to Unity Catalog, must be in the format
            ``{catalog_name}.{schema_name}``. (default: ``''``)
        model_registry_uri (str, optional): The URI of the model registry to use. To
            register models to Unity Catalog, set to ``databricks-uc``. (default: ``None``)
        synchronous (bool, optional): Whether to log synchronously. If ``True``, MLflow
            logs synchronously to the MLflow backend. If ``False``, MLflow logs
            asynchronously. (default: ``False``)
        log_system_metrics (bool, optional): Whether to log system metrics. If ``True``,
            MLflow logs system metrics (CPU/GPU/memory/network usage) during training.
            (default: ``True``)
        rename_metrics (dict[str, str], optional): A dict used to rename metrics; requires
            an exact match on the key. (default: ``None``)
        ignore_metrics (list[str], optional): A list of glob patterns for metrics to ignore
            when logging. (default: ``None``)
        ignore_hyperparameters (list[str], optional): A list of glob patterns for
            hyperparameters to ignore when logging. (default: ``None``)
        run_group (str, optional): A string used to group runs together. (default: ``None``)
        resume (bool, optional): If ``True``, Composer will search for an existing run
            tagged with the ``run_name`` and resume it. If no existing run is found, a new
            run will be created. If ``False``, Composer will create a new run.
            (default: ``False``)
        logging_buffer_seconds (int, optional): The amount of time, in seconds, that MLflow
            waits before sending logs to the MLflow tracking server. Metrics/params/tags
            logged within this buffer time will be grouped in batches before being sent to
            the backend. (default: ``10``)
""" def __init__( self, experiment_name: Optional[str] = None, run_name: Optional[str] = None, tags: Optional[dict[str, Any]] = None, tracking_uri: Optional[Union[str, pathlib.Path]] = None, rank_zero_only: bool = True, flush_interval: int = 10, model_registry_prefix: str = '', model_registry_uri: Optional[str] = None, synchronous: bool = False, log_system_metrics: bool = True, rename_metrics: Optional[dict[str, str]] = None, ignore_metrics: Optional[list[str]] = None, ignore_hyperparameters: Optional[list[str]] = None, run_group: Optional[str] = None, resume: bool = False, logging_buffer_seconds: Optional[int] = 10, ) -> None: try: import mlflow from mlflow import MlflowClient except ImportError as e: raise MissingConditionalImportError( extra_deps_group='mlflow', conda_package='mlflow', conda_channel='conda-forge', ) from e self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0 self.experiment_name = experiment_name self.run_name = run_name self.run_group = run_group self.tags = tags or {} self.model_registry_prefix = model_registry_prefix self.model_registry_uri = model_registry_uri self.synchronous = synchronous self.log_system_metrics = log_system_metrics self.rename_metrics = {} if rename_metrics is None else rename_metrics self.ignore_metrics = [] if ignore_metrics is None else ignore_metrics self.ignore_hyperparameters = [] if ignore_hyperparameters is None else ignore_hyperparameters if self.model_registry_uri == 'databricks-uc': if len(self.model_registry_prefix.split('.')) != 2: raise ValueError( f'When registering to Unity Catalog, model_registry_prefix must be in the format ' + f'{{catalog_name}}.{{schema_name}}, but got {self.model_registry_prefix}', ) self.resume = resume if logging_buffer_seconds: os.environ['MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS'] = str(logging_buffer_seconds,) if log_system_metrics: # Set system metrics sampling interval and samples before logging so that system metrics # are collected every 5s, and aggregated over 6 samples before being logged # (logging per 30s). 
            mlflow.set_system_metrics_samples_before_logging(6)
            mlflow.set_system_metrics_sampling_interval(5)

        self._rank_zero_only = rank_zero_only
        self._last_flush_time = time.time()
        self._flush_interval = flush_interval

        self._experiment_id: Optional[str] = None
        self._run_id = None
        self.run_url = None

        if self._enabled:
            if tracking_uri is None and os.getenv('DATABRICKS_TOKEN') is not None:
                tracking_uri = 'databricks'
            if tracking_uri is None:
                tracking_uri = mlflow.get_tracking_uri()
            self.tracking_uri = str(tracking_uri)
            mlflow.set_tracking_uri(self.tracking_uri)

            if self.model_registry_uri is not None:
                mlflow.set_registry_uri(self.model_registry_uri)

            # Set up MLflow state
            self._run_id = None
            if self.experiment_name is None:
                self.experiment_name = os.getenv(
                    mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,  # type: ignore
                    DEFAULT_MLFLOW_EXPERIMENT_NAME,
                )
            assert self.experiment_name is not None  # type hint

            if os.getenv('DATABRICKS_TOKEN') is not None and not self.experiment_name.startswith((
                '/Users/',
                '/Shared/',
            )):
                try:
                    from databricks.sdk import WorkspaceClient
                except ImportError as e:
                    raise MissingConditionalImportError(
                        extra_deps_group='mlflow',
                        conda_package='databricks-sdk',
                        conda_channel='conda-forge',
                    ) from e
                databricks_username = WorkspaceClient().current_user.me().user_name or ''
                self.experiment_name = os.path.join(
                    '/Users',
                    databricks_username,
                    self.experiment_name.strip('/'),
                )

            self._mlflow_client = MlflowClient(self.tracking_uri)
            # Set experiment
            env_exp_id = os.getenv(
                mlflow.environment_variables.MLFLOW_EXPERIMENT_ID.name,  # pyright: ignore[reportGeneralTypeIssues]
                None,
            )
            if env_exp_id is not None:
                self._experiment_id = env_exp_id
            else:
                exp_from_name = self._mlflow_client.get_experiment_by_name(name=self.experiment_name)
                if exp_from_name is not None:
                    self._experiment_id = exp_from_name.experiment_id
                else:
                    self._experiment_id = self._mlflow_client.create_experiment(name=self.experiment_name)

    def _start_mlflow_run(self, state):
        import mlflow

        # This function is only called if self._enabled is True, and therefore self._experiment_id is not None.
        assert self._experiment_id is not None
        env_run_id = os.getenv(
            mlflow.environment_variables.MLFLOW_RUN_ID.name,  # pyright: ignore[reportGeneralTypeIssues]
            None,
        )
        if env_run_id is not None:
            self._run_id = env_run_id
        elif self.resume:
            # Search for an existing run tagged with this Composer run if `self.resume=True`.
            run_name = self.tags['run_name']
            existing_runs = mlflow.search_runs(
                experiment_ids=[self._experiment_id],
                filter_string=f'tags.run_name = "{run_name}"',
                output_format='list',
            )
            if len(existing_runs) > 0:
                self._run_id = existing_runs[0].info.run_id
                log.debug(f'Resuming mlflow run with run id: {self._run_id}')
            else:
                log.debug('Creating a new mlflow run as `resume` was set to True but no previous run was found.')
                new_run = self._mlflow_client.create_run(
                    experiment_id=self._experiment_id,
                    run_name=self.run_name,
                )
                self._run_id = new_run.info.run_id
        else:
            # Create a new run if `env_run_id` is not set or `self.resume=False`.
            new_run = self._mlflow_client.create_run(
                experiment_id=self._experiment_id,
                run_name=self.run_name,
            )
            self._run_id = new_run.info.run_id

        tags = self.tags or {}
        if self.run_group:
            tags['run_group'] = self.run_group

        mlflow.start_run(
            run_id=self._run_id,
            tags=tags,  # Pass the local `tags` dict so the `run_group` tag is not dropped when `self.tags` is empty.
            log_system_metrics=self.log_system_metrics,
        )

        if self.tracking_uri == 'databricks':
            # Start a background process to monitor the job to report the job status to MLflow.
            self.monitor_process = MlflowMonitorProcess(
                os.getpid(),
                self._run_id,
                self.tracking_uri,
            )
            self.monitor_process.start()

    def _global_exception_handler(self, exc_type, exc_value, exc_traceback):
        """Catch global exception."""
        self._global_exception_occurred += 1
        sys.__excepthook__(exc_type, exc_value, exc_traceback)

    def init(self, state: State, logger: Logger) -> None:
        del logger  # unused

        if self.run_name is None:
            self.run_name = state.run_name

        self._global_exception_occurred = 0

        # Store the Composer run name in the MLflow run tags so it can be retrieved for autoresume.
        self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name)

        # Adjust name and group based on `rank_zero_only`.
        if not self._rank_zero_only:
            self.run_name += f'-rank{dist.get_global_rank()}'

        # Register the global exception handler so that uncaught exceptions are tracked.
        sys.excepthook = self._global_exception_handler

        # Start run
        if self._enabled:
            self._start_mlflow_run(state)

        # If rank zero only, broadcast the MLflow experiment and run IDs to other ranks, so the MLflow run info is
        # available to other ranks during runtime.
        if self._rank_zero_only:
            mlflow_ids_list = [self._experiment_id, self._run_id]
            dist.broadcast_object_list(mlflow_ids_list, src=0)
            self._experiment_id, self._run_id = mlflow_ids_list

    def after_load(self, state: State, logger: Logger) -> None:
        logger.log_hyperparameters({
            'mlflow_experiment_id': self._experiment_id,
            'mlflow_run_id': self._run_id,
        })
        self.run_url = posixpath.join(
            os.environ.get('DATABRICKS_HOST', ''),
            'ml',
            'experiments',
            str(self._experiment_id),
            'runs',
            str(self._run_id),
        )

    def log_table(
        self,
        columns: list[str],
        rows: list[list[Any]],
        name: str = 'Table',
        step: Optional[int] = None,
    ) -> None:
        del step
        if self._enabled:
            try:
                import pandas as pd
            except ImportError as e:
                raise MissingConditionalImportError(
                    extra_deps_group='pandas',
                    conda_package='pandas',
                    conda_channel='conda-forge',
                ) from e
            table = pd.DataFrame.from_records(data=rows, columns=columns)
            assert isinstance(self._run_id, str)
            self._mlflow_client.log_table(
                run_id=self._run_id,
                data=table,
                artifact_file=f'{name}.json',
            )

    def rename(self, key: str):
        return self.rename_metrics.get(key, key)

    def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
        from mlflow import log_metrics
        if self._enabled:
            # Convert all metrics to floats to placate mlflow.
            metrics = {
                self.rename(k): float(v)
                for k, v in metrics.items()
                if not any(fnmatch.fnmatch(k, pattern) for pattern in self.ignore_metrics)
            }
            log_metrics(
                metrics=metrics,
                step=step,
                synchronous=self.synchronous,
            )

    def log_hyperparameters(self, hyperparameters: dict[str, Any]):
        from mlflow import log_params
        if self._enabled:
            hyperparameters = {
                k: v
                for k, v in hyperparameters.items()
                if not any(fnmatch.fnmatch(k, pattern) for pattern in self.ignore_hyperparameters)
            }
            log_params(
                params=hyperparameters,
                synchronous=self.synchronous,
            )
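
    # Filtering sketch (illustrative): `ignore_metrics` uses `fnmatch`-style glob
    # patterns, while `rename_metrics` requires an exact key match. With a
    # hypothetical logger configured as
    #
    #     logger = MLFlowLogger(
    #         rename_metrics={'loss/train/total': 'train_loss'},
    #         ignore_metrics=['throughput/*'],
    #     )
    #
    # a call like `logger.log_metrics({'loss/train/total': 0.7,
    # 'throughput/samples_per_sec': 512.0})` would log only `train_loss=0.7`:
    # `throughput/samples_per_sec` matches the `throughput/*` glob and is dropped
    # before the rename is applied.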

    def register_model(
        self,
        model_uri: str,
        name: str,
        await_registration_for: int = 300,
        tags: Optional[dict[str, Any]] = None,
    ) -> 'ModelVersion':
        """Register a model to model registry.

        Args:
            model_uri (str): The URI of the model to register.
            name (str): The name of the model to register. Will be appended to ``model_registry_prefix``.
            await_registration_for (int, optional): The number of seconds to wait for the model to be registered. Defaults to 300.
            tags (Optional[dict[str, Any]], optional): A dictionary of tags to add to the model. Defaults to None.

        Returns:
            ModelVersion: The registered model.
        """
        if self._enabled:
            full_name = f'{self.model_registry_prefix}.{name}' if len(self.model_registry_prefix) > 0 else name

            import mlflow
            return mlflow.register_model(
                model_uri=model_uri,
                name=full_name,
                await_registration_for=await_registration_for,
                tags=tags,
            )
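
    # Name-composition note (illustrative): with `model_registry_prefix` set, the
    # registered name is `f'{model_registry_prefix}.{name}'`. For example, a
    # hypothetical Unity Catalog setup with prefix 'main.my_schema' and
    # `register_model(model_uri='runs:/abc123/model', name='my_model')` would
    # register the model as 'main.my_schema.my_model'; with an empty prefix the
    # name is used as-is.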

    def save_model(self, flavor: Literal['transformers', 'peft'], **kwargs):
        """Save a model to MLflow.

        Note: The ``'peft'`` flavor is experimental and the API is subject to change without warning.

        Args:
            flavor (Literal['transformers', 'peft']): The MLflow model flavor to use. Currently only ``'transformers'`` and ``'peft'`` are supported.
            **kwargs: Keyword arguments to pass to the MLflow model saving function.

        Raises:
            NotImplementedError: If ``flavor`` is not ``'transformers'`` or ``'peft'``.
            ValueError: If ``flavor`` is ``'peft'`` and the expected keyword arguments are missing.
        """
        if self._enabled:
            import mlflow
            if flavor == 'transformers':
                mlflow.transformers.save_model(**kwargs)
            elif flavor == 'peft':
                import transformers

                # TODO: Remove after mlflow fixes the bug that makes this necessary
                mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''  # type: ignore

                # This is a temporary workaround until MLflow adds full support for saving PEFT models.
                # https://github.com/mlflow/mlflow/issues/9256
                log.warning(
                    'Saving PEFT models using MLflow is experimental and the API is subject to change without warning.',
                )

                expected_keys = {'path', 'save_pretrained_dir'}
                if not expected_keys.issubset(kwargs.keys()):
                    raise ValueError(f'Expected keys {expected_keys} but got {kwargs.keys()}')

                # This does not implement predict for now, as we will wait for the full MLflow support
                # for PEFT models.
                class PeftModel(mlflow.pyfunc.PythonModel):

                    def load_context(self, context):
                        self.model = transformers.AutoModelForCausalLM.from_pretrained(
                            context.artifacts['lora_checkpoint'],
                        )
                        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                            context.artifacts['lora_checkpoint'],
                        )

                from mlflow.models.signature import ModelSignature
                from mlflow.types import ColSpec, DataType, Schema

                # This is faked for now, until MLflow adds full support for saving PEFT models.
                input_schema = Schema([
                    ColSpec(DataType.string, 'fake_input'),
                ])
                output_schema = Schema([ColSpec(DataType.string)])
                signature = ModelSignature(inputs=input_schema, outputs=output_schema)

                # Symlink the directory so that we control the path that MLflow saves the model under.
                os.symlink(kwargs['save_pretrained_dir'], 'lora_checkpoint')

                mlflow.pyfunc.save_model(
                    path=kwargs['path'],
                    artifacts={'lora_checkpoint': 'lora_checkpoint'},
                    python_model=PeftModel(),
                    signature=signature,
                )

                os.unlink('lora_checkpoint')
            else:
                raise NotImplementedError(f'flavor {flavor} not supported.')
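
    # Usage sketch (illustrative): for the experimental 'peft' flavor, both `path`
    # (where the pyfunc model is written) and `save_pretrained_dir` (an existing
    # directory produced by a PEFT model's save_pretrained) are required, e.g.
    #
    #     logger.save_model(
    #         flavor='peft',
    #         path='/tmp/mlflow_peft_model',          # hypothetical output path
    #         save_pretrained_dir='/tmp/lora_ckpt',   # hypothetical checkpoint dir
    #     )
    #
    # The 'transformers' flavor forwards **kwargs directly to
    # mlflow.transformers.save_model.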

    def log_model(self, flavor: Literal['transformers'], **kwargs):
        """Log a model to MLflow.

        Args:
            flavor (Literal['transformers']): The MLflow model flavor to use. Currently only ``'transformers'`` is supported.
            **kwargs: Keyword arguments to pass to the MLflow model logging function.

        Raises:
            NotImplementedError: If ``flavor`` is not ``'transformers'``.
        """
        if self._enabled:
            import mlflow
            if flavor == 'transformers':
                mlflow.transformers.log_model(**kwargs)
            else:
                raise NotImplementedError(f'flavor {flavor} not supported.')
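
    # Usage sketch (illustrative, hypothetical names): kwargs are forwarded verbatim
    # to mlflow.transformers.log_model, so a call might look like
    #
    #     logger.log_model(
    #         flavor='transformers',
    #         transformers_model={'model': model, 'tokenizer': tokenizer},
    #         artifact_path='model',
    #     )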

    def register_model_with_run_id(
        self,
        model_uri: str,
        name: str,
        await_creation_for: int = 300,
        tags: Optional[dict[str, Any]] = None,
    ):
        """Similar to ``register_model``, but uses a different MLflow API to allow passing in the run id.

        Args:
            model_uri (str): The URI of the model to register.
            name (str): The name of the model to register. Will be appended to ``model_registry_prefix``.
            await_creation_for (int, optional): The number of seconds to wait for the model to be registered. Defaults to 300.
            tags (Optional[dict[str, Any]], optional): A dictionary of tags to add to the model. Defaults to None.
        """
        if self._enabled:
            from mlflow.exceptions import MlflowException
            from mlflow.protos.databricks_pb2 import (
                ALREADY_EXISTS,
                RESOURCE_ALREADY_EXISTS,
                ErrorCode,
            )

            full_name = f'{self.model_registry_prefix}.{name}' if len(self.model_registry_prefix) > 0 else name

            # This try/catch code is copied from
            # https://github.com/mlflow/mlflow/blob/3ba1e50e90a38be19920cb9118593a43d7cfa90e/mlflow/tracking/_model_registry/fluent.py#L90-L103
            try:
                create_model_response = self._mlflow_client.create_registered_model(full_name)
                log.info(f'Successfully registered model {name} with {create_model_response.name}')
            except MlflowException as e:
                if e.error_code in (
                    ErrorCode.Name(RESOURCE_ALREADY_EXISTS),
                    ErrorCode.Name(ALREADY_EXISTS),
                ):
                    log.info(f'Registered model {name} already exists. Creating a new version of this model...')
                else:
                    raise e

            create_version_response = self._mlflow_client.create_model_version(
                name=full_name,
                source=model_uri,
                run_id=self._run_id,
                await_creation_for=await_creation_for,
                tags=tags,
            )

            log.info(
                f'Successfully created model version {create_version_response.version} for model {create_version_response.name}',
            )

    def log_images(
        self,
        images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
        name: str = 'image',
        channels_last: bool = False,
        step: Optional[int] = None,
        masks: Optional[dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]]] = None,
        mask_class_labels: Optional[dict[int, str]] = None,
        use_table: bool = True,
    ):
        unused_args = (masks, mask_class_labels)  # Unused (only for wandb)
        if any(unused_args):
            warnings.warn(
                textwrap.dedent(
                    f"""MLFlowLogger does not support masks or class labels,
                    but got masks={masks}, mask_class_labels={mask_class_labels}""",
                ),
            )
        if self._enabled:
            if not isinstance(images, Sequence) and images.ndim <= 3:
                images = [images]
            for im_ind, image in enumerate(images):
                image = _convert_to_mlflow_image(image, channels_last)
                assert isinstance(self._run_id, str)
                self._mlflow_client.log_image(
                    image=image,
                    key=f'{name}_{im_ind}',
                    run_id=self._run_id,
                    step=step,
                )

    def post_close(self):
        if self._enabled:
            if hasattr(self, 'monitor_process'):
                # Check if there is an uncaught exception, which means `post_close()` is triggered
                # due to a program crash.
                finish_with_exception = self._global_exception_occurred == 1
                if finish_with_exception:
                    self.monitor_process.crash()
                    return
                # Stop the monitor process since it's entering the cleanup phase.
                self.monitor_process.stop()

            import mlflow
            assert isinstance(self._run_id, str)
            mlflow.flush_async_logging()
            exc_type, exc_value, exc_tb = sys.exc_info()
            if (exc_type, exc_value, exc_tb) == (None, None, None):
                current_status = self._mlflow_client.get_run(self._run_id).info.status
                if current_status == 'RUNNING':
                    self._mlflow_client.set_terminated(self._run_id, status='FINISHED')
            else:
                # Record that there was an error.
                self._mlflow_client.set_terminated(self._run_id, status='FAILED')
            mlflow.end_run()

            if hasattr(self, 'monitor_process'):
                self.monitor_process.join()
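

# Usage sketch (illustrative): a minimal, hypothetical configuration of this logger
# with the Composer Trainer. Names such as `my_model`, `my_dataloader`, and the
# experiment path are placeholders, not part of this module.
#
#     from composer import Trainer
#     from composer.loggers import MLFlowLogger
#
#     mlflow_logger = MLFlowLogger(
#         experiment_name='/Users/me@example.com/my-experiment',
#         tracking_uri='databricks',
#         log_system_metrics=True,
#     )
#     trainer = Trainer(
#         model=my_model,
#         train_dataloader=my_dataloader,
#         max_duration='1ep',
#         loggers=[mlflow_logger],
#     )
#     trainer.fit()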


def _convert_to_mlflow_image(
    image: Union[np.ndarray, torch.Tensor],
    channels_last: bool,
) -> np.ndarray:
    if isinstance(image, torch.Tensor):
        image = image.data.cpu().numpy()

    # Error out for empty arrays or weird arrays of dimension 0.
    if np.any(np.equal(image.shape, 0)):
        raise ValueError(f'Got an image (shape {image.shape}) with at least one dimension being 0!')

    # Squeeze any singleton dimensions and then add them back in if the image has fewer than
    # 3 dimensions.
    image = image.squeeze()

    # Add in length-one dimensions to get back up to 3, putting channels last.
    if image.ndim == 1:
        image = np.expand_dims(image, (1, 2))
        channels_last = True
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
        channels_last = True

    if image.ndim != 3:
        raise ValueError(
            textwrap.dedent(
                f'''Input image must be 3 dimensions, but instead got {image.ndim} dims at shape: {image.shape}.
                Your input image was interpreted as a batch of {image.ndim}-dimensional images because you either
                specified a {image.ndim + 1}D image or a list of {image.ndim}D images.
                Please specify either a 4D image or a list of 3D images.''',
            ),
        )
    assert isinstance(image, np.ndarray)
    if not channels_last:
        image = image.transpose(1, 2, 0)
    if image.shape[-1] not in [1, 3, 4]:
        raise ValueError(
            textwrap.dedent(
                f'''Input image must have 1, 3, or 4 channels, but instead got {image.shape[-1]} channels
                at shape: {image.shape}.
                Please specify either a 1-, 3-, or 4-channel image or a list of 1-, 3-, or 4-channel images.''',
            ),
        )
    return image
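

# Shape-handling sketch (illustrative): _convert_to_mlflow_image normalizes inputs
# to HWC numpy arrays. Assuming the function behaves as written above:
#
#     import numpy as np
#
#     # A CHW array of shape (3, 32, 32) is transposed to (32, 32, 3).
#     chw = np.random.rand(3, 32, 32)
#     assert _convert_to_mlflow_image(chw, channels_last=False).shape == (32, 32, 3)
#
#     # A 2D grayscale array (32, 32) gains a trailing channel: (32, 32, 1).
#     gray = np.random.rand(32, 32)
#     assert _convert_to_mlflow_image(gray, channels_last=False).shape == (32, 32, 1)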