# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Log to `MLFlow <https://www.mlflow.org/docs/latest/index.html>."""
from __future__ import annotations
import pathlib
from typing import Any, Dict, Optional, Union
from composer.core.state import State
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError, dist
__all__ = ['MLFlowLogger']
[docs]class MLFlowLogger(LoggerDestination):
"""Log to `MLFlow <https://www.mlflow.org/docs/latest/index.html>`_.
Args:
experiment_name: (str, optional): MLFLow experiment name,
run_name: (str, optional): MLFlow run name.
tracking_uri (str | pathlib.Path, optional): MLFlow tracking uri, the URI to the
remote or local endpoint where logs are stored (If none it is set to `./mlruns`)
rank_zero_only (bool, optional): Whether to log only on the rank-zero process
(default: ``True``).
"""
def __init__(
self,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
tracking_uri: Optional[Union[str, pathlib.Path]] = None,
rank_zero_only: bool = True,
) -> None:
try:
import mlflow
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='mlflow',
conda_package='mlflow',
conda_channel='conda-forge') from e
del mlflow
self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0
self.run_name = run_name
self.experiment_name = experiment_name
self._rank_zero_only = rank_zero_only
self.tracking_uri = tracking_uri
def init(self, state: State, logger: Logger) -> None:
import mlflow
del logger # unused
# Use the logger run name if the name is not set.
if self.run_name is None:
self.run_name = state.run_name
# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
self.run_name += f'-rank{dist.get_global_rank()}'
if self.experiment_name is None:
self.experiment_name = 'my-mlflow-experiment'
if self._enabled:
if self.tracking_uri is not None:
mlflow.set_tracking_uri(self.tracking_uri)
mlflow.set_experiment(self.experiment_name)
mlflow.start_run(run_name=self.run_name)
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
import mlflow
if self._enabled:
# Convert all metrics to floats to placate mlflow.
metrics = {k: float(v) for k, v in metrics.items()}
mlflow.log_metrics(metrics=metrics, step=step)
def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
import mlflow
if self._enabled:
mlflow.log_params(params=hyperparameters)
def post_close(self):
import mlflow
if self._enabled:
mlflow.end_run()