# Source code for composer.loggers.mlflow_logger

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log to `MLFlow <https://www.mlflow.org/docs/latest/index.html>`_."""

from __future__ import annotations

import pathlib
from typing import Any, Dict, Optional, Union

from composer.core.state import State
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist

__all__ = ['MLFlowLogger']


class MLFlowLogger(LoggerDestination):
    """Log to `MLFlow <https://www.mlflow.org/docs/latest/index.html>`_.

    Args:
        experiment_name (str, optional): MLFlow experiment name. If ``None``,
            ``'my-mlflow-experiment'`` is used once :meth:`init` runs.
        run_name (str, optional): MLFlow run name. If ``None``, the run name
            stored on the training :class:`State` is used. When
            ``rank_zero_only`` is ``False``, ``-rank{global_rank}`` is appended
            to the name.
        tracking_uri (str | pathlib.Path, optional): MLFlow tracking URI, i.e.
            the remote or local endpoint where logs are stored. If ``None``,
            the tracking URI is not set by this logger and MLFlow's own default
            applies (documented upstream as ``./mlruns``).
        rank_zero_only (bool, optional): Whether to log only on the rank-zero
            process (default: ``True``).

    Raises:
        MissingConditionalImportError: If the ``mlflow`` package cannot be
            imported when the logger is constructed.
    """

    def __init__(
        self,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tracking_uri: Optional[Union[str, pathlib.Path]] = None,
        rank_zero_only: bool = True,
    ) -> None:
        # Probe for the optional dependency up front so a missing install is
        # reported at construction time rather than on the first logging call.
        try:
            import mlflow
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='mlflow',
                conda_package='mlflow',
                conda_channel='conda-forge',
            ) from e
        else:
            # The module is only needed for the availability check here.
            del mlflow

        # The global rank is only queried when logging is restricted to rank zero.
        self._enabled = dist.get_global_rank() == 0 if rank_zero_only else True

        self.run_name = run_name
        self.experiment_name = experiment_name
        self._rank_zero_only = rank_zero_only
        self.tracking_uri = tracking_uri

    def init(self, state: State, logger: Logger) -> None:
        """Resolve the run/experiment names and start the MLFlow run.

        Args:
            state (State): Training state; supplies the fallback run name.
            logger (Logger): Unused.
        """
        import mlflow
        del logger  # unused

        # Fall back to the run name stored on the state when none was given.
        if self.run_name is None:
            self.run_name = state.run_name

        # Every rank logs its own run in this mode, so tag the name with the rank.
        if not self._rank_zero_only:
            self.run_name = self.run_name + f'-rank{dist.get_global_rank()}'

        if self.experiment_name is None:
            self.experiment_name = 'my-mlflow-experiment'

        if not self._enabled:
            return

        if self.tracking_uri is not None:
            mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)
        mlflow.start_run(run_name=self.run_name)

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """Log ``metrics`` to MLFlow, converting each value with ``float()``.

        Args:
            metrics (Dict[str, Any]): Mapping of metric name to value.
            step (int, optional): Step passed through to ``mlflow.log_metrics``.
        """
        import mlflow

        if not self._enabled:
            return

        # Convert all metrics to floats to placate mlflow.
        float_metrics = {name: float(value) for name, value in metrics.items()}
        mlflow.log_metrics(metrics=float_metrics, step=step)

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
        """Log ``hyperparameters`` to MLFlow as run parameters.

        Args:
            hyperparameters (Dict[str, Any]): Mapping of parameter name to value,
                passed unchanged to ``mlflow.log_params``.
        """
        import mlflow

        if not self._enabled:
            return

        mlflow.log_params(params=hyperparameters)

    def post_close(self):
        """End the active MLFlow run on processes where logging is enabled."""
        import mlflow

        if not self._enabled:
            return

        mlflow.end_run()