# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""MLflow Artifacts object store."""
from __future__ import annotations
import logging
import os
import pathlib
import tempfile
from typing import Callable, List, Optional, Tuple, Union
from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError

__all__ = ['MLFlowObjectStore']

MLFLOW_DATABRICKS_TRACKING_URI = 'databricks'
MLFLOW_DBFS_PATH_PREFIX = 'databricks/mlflow-tracking/'
DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store'
MLFLOW_EXPERIMENT_ID_FORMAT_KEY = 'mlflow_experiment_id'
MLFLOW_RUN_ID_FORMAT_KEY = 'mlflow_run_id'
MLFLOW_EXPERIMENT_ID_PLACEHOLDER = '{' + MLFLOW_EXPERIMENT_ID_FORMAT_KEY + '}'
MLFLOW_RUN_ID_PLACEHOLDER = '{' + MLFLOW_RUN_ID_FORMAT_KEY + '}'
log = logging.getLogger(__name__)


def _wrap_mlflow_exceptions(uri: str, e: Exception):
"""Wraps retryable MLflow errors in ObjectStoreTransientError for automatic retry handling."""
from mlflow.exceptions import (
ABORTED,
DATA_LOSS,
DEADLINE_EXCEEDED,
ENDPOINT_NOT_FOUND,
INTERNAL_ERROR,
INVALID_STATE,
NOT_FOUND,
REQUEST_LIMIT_EXCEEDED,
RESOURCE_DOES_NOT_EXIST,
RESOURCE_EXHAUSTED,
TEMPORARILY_UNAVAILABLE,
ErrorCode,
MlflowException,
)
    # See https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/exceptions.py for the error codes used here,
    # and https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/protos/databricks.proto for their descriptions.
retryable_server_codes = [
ErrorCode.Name(code) for code in [
DATA_LOSS,
INTERNAL_ERROR,
INVALID_STATE,
TEMPORARILY_UNAVAILABLE,
DEADLINE_EXCEEDED,
]
]
retryable_client_codes = [ErrorCode.Name(code) for code in [ABORTED, REQUEST_LIMIT_EXCEEDED, RESOURCE_EXHAUSTED]]
not_found_codes = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]]
if isinstance(e, MlflowException):
error_code = e.error_code # pyright: ignore
if error_code in retryable_server_codes or error_code in retryable_client_codes:
raise ObjectStoreTransientError(error_code) from e
elif error_code in not_found_codes:
raise FileNotFoundError(f'Object {uri} not found') from e
raise e


class MLFlowObjectStore(ObjectStore):
"""Utility class for uploading and downloading artifacts from MLflow.
It can be initializd for an existing run, a new run in an existing experiment, the active run used by the `mlflow`
module, or a new run in a new experiment. See the documentation for ``path`` for more details.
.. note::
At this time, only Databricks-managed MLflow with a 'databricks' tracking URI is supported.
Using this object store requires configuring Databricks authentication through a configuration file or
environment variables. See
https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication
Unlike other object stores, the DBFS URI scheme for MLflow artifacts has no bucket, and the path is prefixed
with the artifacts root directory for a given experiment/run,
`databricks/mlflow-tracking/<experiment_id>/<run_id>/`. However, object names are also sometimes passed by
upstream code as artifact paths relative to this root, rather than the full path. To keep upstream code simple,
:class:`MLFlowObjectStore` accepts both relative MLflow artifact paths and absolute DBFS paths as object names.
If an object name takes the form of
`databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`,
it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLflow.
Otherwise, the object name is assumed to be a relative MLflow artifact path, and the full provided name will be
used as the artifact path when uploading to MLflow.
Args:
path (str): A DBFS path of the form
`databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<path>`.
`experiment_id` and `run_id` can be set as the format string placeholders
`{mlflow_experiment_id}` and `{mlflow_run_id}'`.
If both `experiment_id` and `run_id` are set as placeholders, the MLFlowObjectStore will be associated with
the currently active MLflow run if one exists. If no active run exists, a new run will be created under a
default experiment name, or the experiment name specified by the `MLFLOW_EXPERIMENT_NAME` environment
variable if one is set.
If `experiment_id` is provided and `run_id` is not, the MLFlowObjectStore will create a new run in the
provided experiment.
Providing a `run_id` without an `experiment_id` will raise an error.
multipart_upload_chunk_size(int, optional): The maximum size of a single chunk in an MLflow multipart upload.
The maximum number of chunks supported by MLflow is 10,000, so the max file size that can
be uploaded is `10 000 * multipart_upload_chunk_size`. Defaults to 100MB for a max upload size of 1TB.
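
    Example:
        A minimal usage sketch (the experiment and run IDs shown are hypothetical):

        .. code-block:: python

            # Bind the store to a specific existing experiment and run.
            store = MLFlowObjectStore('databricks/mlflow-tracking/123/456/artifacts/checkpoints')

            # Or use placeholders to defer to the active run (or a newly created one).
            store = MLFlowObjectStore(
                'databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/checkpoints',
            )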
"""

    def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 1024) -> None:
try:
import mlflow
from mlflow import MlflowClient
except ImportError as e:
raise MissingConditionalImportError('mlflow', conda_package='mlflow>=2.9.2,<3.0') from e
try:
from databricks.sdk import WorkspaceClient
except ImportError as e:
raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e
tracking_uri = os.getenv(
mlflow.environment_variables.MLFLOW_TRACKING_URI.name, # pyright: ignore[reportGeneralTypeIssues]
MLFLOW_DATABRICKS_TRACKING_URI,
)
if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI:
raise ValueError(
'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. '
f'Environment variable `MLFLOW_TRACKING_URI` is set to a non-Databricks URI {tracking_uri}. '
f'Please unset it or set the value to `{MLFLOW_DATABRICKS_TRACKING_URI}`.',
)
# Use the Databricks WorkspaceClient to check that credentials are set up correctly.
try:
WorkspaceClient()
except Exception as e:
raise ValueError(
                'Databricks SDK credentials not correctly set up. '
                'Visit https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication '
                'for ways to set up credentials.',
) from e
self._mlflow_client = MlflowClient(tracking_uri)
mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set( # pyright: ignore[reportGeneralTypeIssues]
multipart_upload_chunk_size,
)
experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)
if experiment_id == MLFLOW_EXPERIMENT_ID_PLACEHOLDER:
experiment_id = None
if run_id == MLFLOW_RUN_ID_PLACEHOLDER:
run_id = None
# Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided.
self.experiment_id, self.run_id = self._init_run_info(experiment_id, run_id)

    def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> Tuple[str, str]:
"""Returns the experiment ID and run ID for the MLflow run backing this object store.
In a distributed setting, this should only be called on the rank 0 process.
"""
import mlflow
if experiment_id is None:
if run_id is not None:
raise ValueError('A `run_id` cannot be provided without a valid `experiment_id`.')
active_run = mlflow.active_run()
if active_run is not None:
experiment_id = active_run.info.experiment_id
run_id = active_run.info.run_id
log.debug(f'MLFlowObjectStore using active MLflow run {run_id=}')
else:
# If no active run exists, create a new run for the default experiment.
mlflow_env_var_name = mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name # pyright: ignore[reportGeneralTypeIssues]
experiment_name = os.getenv(mlflow_env_var_name, DEFAULT_MLFLOW_EXPERIMENT_NAME)
experiment = self._mlflow_client.get_experiment_by_name(experiment_name)
if experiment is not None:
experiment_id = experiment.experiment_id
else:
experiment_id = self._mlflow_client.create_experiment(experiment_name)
run_id = self._mlflow_client.create_run(experiment_id).info.run_id
log.debug(
                    f'MLFlowObjectStore using a new MLflow run {run_id=} '
f'for new experiment "{experiment_name}" {experiment_id=}',
)
else:
if run_id is not None:
# If a `run_id` is provided, check that it belongs to the provided experiment.
run = self._mlflow_client.get_run(run_id)
if run.info.experiment_id != experiment_id:
raise ValueError(
f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. '
f'Found experiment {run.info.experiment_id}.',
)
log.debug(
f'MLFlowObjectStore using provided MLflow run {run_id=} '
f'for provided experiment {experiment_id=}',
)
else:
# If no `run_id` is provided, create a new run in the provided experiment.
run = self._mlflow_client.create_run(experiment_id)
run_id = run.info.run_id
log.debug(
f'MLFlowObjectStore using new MLflow run {run_id=} '
f'for provided experiment {experiment_id=}',
)
if experiment_id is None or run_id is None:
raise ValueError('MLFlowObjectStore failed to initialize experiment and run ID.')
return experiment_id, run_id

    @staticmethod
def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
"""Parses a DBFS path to extract the MLflow experiment ID, run ID, and relative artifact path.
The path is expected to be of the format
`databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
Args:
path (str): The DBFS path to parse.
Returns:
(str, str, str): (experiment_id, run_id, artifact_path)
Raises:
ValueError: If the path is not of the expected format.
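
        Example:
            An illustrative parse (the IDs shown are hypothetical):

            .. code-block:: python

                experiment_id, run_id, artifact_path = MLFlowObjectStore.parse_dbfs_path(
                    'databricks/mlflow-tracking/123/456/artifacts/checkpoints/ep0.pt',
                )
                # experiment_id == '123', run_id == '456', artifact_path == 'checkpoints/ep0.pt'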
"""
if not path.startswith(MLFLOW_DBFS_PATH_PREFIX):
raise ValueError(f'DBFS MLflow path should start with {MLFLOW_DBFS_PATH_PREFIX}. Got: {path}')
        # Strip `databricks/mlflow-tracking/` and split into
        # `<experiment_id>`, `<run_id>`, `'artifacts'`, `<relative_path>`
subpath = path[len(MLFLOW_DBFS_PATH_PREFIX):]
mlflow_parts = subpath.split('/', maxsplit=3)
if len(mlflow_parts) != 4 or mlflow_parts[2] != 'artifacts':
raise ValueError(
f'Databricks MLflow artifact path expected to be of the format '
                f'{MLFLOW_DBFS_PATH_PREFIX}<experiment_id>/<run_id>/artifacts/<relative_path>. '
f'Found {path=}',
)
return mlflow_parts[0], mlflow_parts[1], mlflow_parts[3]

    def get_artifact_path(self, object_name: str) -> str:
"""Converts an object name into an MLflow relative artifact path.
Args:
object_name (str): The object name to convert. If the object name is a DBFS path beginning with
``MLFLOW_DBFS_PATH_PREFIX``, the path will be parsed to extract the MLflow relative artifact path.
Otherwise, the object name is assumed to be a relative artifact path and will be returned as-is.
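
        Example:
            A sketch of both accepted forms, assuming the store was initialized with the hypothetical
            experiment ID '123' and run ID '456':

            .. code-block:: python

                store.get_artifact_path('databricks/mlflow-tracking/123/456/artifacts/checkpoints/ep0.pt')
                store.get_artifact_path('checkpoints/ep0.pt')
                # Both return 'checkpoints/ep0.pt'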
"""
if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX):
experiment_id, run_id, object_name = self.parse_dbfs_path(object_name)
if (experiment_id != self.experiment_id and experiment_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER):
raise ValueError(
f'Object {object_name} belongs to experiment ID {experiment_id}, '
f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.',
)
if (run_id != self.run_id and run_id != MLFLOW_RUN_ID_PLACEHOLDER):
raise ValueError(
f'Object {object_name} belongs to run ID {run_id}, '
f'but MLFlowObjectStore is associated with run ID {self.run_id}.',
)
return object_name

    def get_dbfs_path(self, object_name: str) -> str:
"""Converts an object name to a full DBFS path."""
artifact_path = self.get_artifact_path(object_name)
return f'{MLFLOW_DBFS_PATH_PREFIX}{self.experiment_id}/{self.run_id}/artifacts/{artifact_path}'

    def get_uri(self, object_name: str) -> str:
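        """See :meth:`~composer.utils.ObjectStore.get_uri`."""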
return 'dbfs:/' + self.get_dbfs_path(object_name)

    def upload_object(
self,
object_name: str,
filename: Union[str, pathlib.Path],
callback: Optional[Callable[[int, int], None]] = None,
):
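        """See :meth:`~composer.utils.ObjectStore.upload_object`."""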
del callback # unused
from mlflow.exceptions import MlflowException
# Extract relative path from DBFS path.
artifact_path = self.get_artifact_path(object_name)
artifact_base_name = os.path.basename(artifact_path)
artifact_dir = os.path.dirname(artifact_path)
# Since MLflow doesn't support uploading artifacts with a different base name than the local file,
# create a temporary symlink to the local file with the same base name as the desired artifact name.
filename = os.path.abspath(filename)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_symlink_path = os.path.join(tmp_dir, artifact_base_name)
os.symlink(filename, tmp_symlink_path)
try:
self._mlflow_client.log_artifact(self.run_id, tmp_symlink_path, artifact_dir)
except MlflowException as e:
_wrap_mlflow_exceptions(self.get_uri(object_name), e)

    def get_object_size(self, object_name: str) -> int:
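        """See :meth:`~composer.utils.ObjectStore.get_object_size`."""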
from mlflow.exceptions import MlflowException
artifact = None
try:
artifact = self._get_artifact_info(object_name)
except MlflowException as e:
_wrap_mlflow_exceptions(self.get_uri(object_name), e)
if artifact is not None:
return artifact.file_size
else:
raise FileNotFoundError(f'Object {object_name} not found')

    def download_object(
self,
object_name: str,
filename: Union[str, pathlib.Path],
overwrite: bool = False,
callback: Optional[Callable[[int, int], None]] = None,
) -> None:
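        """See :meth:`~composer.utils.ObjectStore.download_object`."""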
del callback # unused
from mlflow.exceptions import MlflowException
# Since MlflowClient.download_artifacts only raises MlflowException with 500 Internal Error,
# check for existence to surface a FileNotFoundError if necessary.
artifact_path = self.get_artifact_path(object_name)
artifact_info = self._get_artifact_info(object_name)
if artifact_info is None:
raise FileNotFoundError(f'Object {self.get_dbfs_path(artifact_path)} not found')
filename = os.path.abspath(filename)
if os.path.exists(filename) and not overwrite:
raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')
        # MLflow doesn't support downloading artifacts directly to a specified filename, so instead
# download to a temporary directory and then move the file to the desired location.
with tempfile.TemporaryDirectory() as tmp_dir:
try:
self._mlflow_client.download_artifacts(
run_id=self.run_id,
path=artifact_path,
dst_path=tmp_dir,
)
tmp_path = os.path.join(tmp_dir, artifact_path)
os.makedirs(os.path.dirname(filename), exist_ok=True)
if overwrite:
os.replace(tmp_path, filename)
else:
os.rename(tmp_path, filename)
except MlflowException as e:
_wrap_mlflow_exceptions(self.get_uri(artifact_path), e)

    def list_objects(self, prefix: Optional[str] = None) -> List[str]:
"""See :meth:`~composer.utils.ObjectStore.list_objects`.
MLFlowObjectStore does not support listing objects with a prefix, so the ``prefix`` argument is ignored.
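
        Example:
            A sketch of typical output for the code as written, which collects relative MLflow
            artifact paths (the artifact names shown are hypothetical):

            .. code-block:: python

                store.list_objects()
                # ['checkpoints/ep0.pt', 'metrics/loss.json']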
"""
del prefix # not supported for MLFlowObjectStore
objects = []
self._list_objects_helper(None, objects)
return objects

    def _list_objects_helper(self, prefix: Optional[str], objects: List[str]) -> None:
"""Helper to recursively populate the full list of objects for ``list_objects``.
Args:
prefix (str | None): An artifact path prefix for artifacts to find.
objects (list[str]): The list of DBFS object paths to populate.
"""
from mlflow.exceptions import MlflowException
artifact = None
try:
for artifact in self._mlflow_client.list_artifacts(self.run_id, prefix):
if artifact.is_dir:
self._list_objects_helper(artifact.path, objects)
else:
objects.append(artifact.path)
except MlflowException as e:
uri = '' if artifact is None else self.get_uri(artifact.path)
_wrap_mlflow_exceptions(uri, e)

    def _get_artifact_info(self, object_name: str):
"""Get the :class:`~mlflow.entities.FileInfo` for the given object name.
Args:
object_name (str): The name of the object, either as an absolute DBFS path or a relative MLflow artifact path.
Returns:
Optional[FileInfo]: The :class:`~mlflow.entities.FileInfo` for the object, or None if it does not exist.
"""
# MLflow doesn't support info for a singleton artifact, so we need to list all artifacts in the
# parent path and find the one with the matching name.
artifact_path = self.get_artifact_path(object_name)
artifact_dir = os.path.dirname(artifact_path)
artifacts = self._mlflow_client.list_artifacts(self.run_id, artifact_dir)
for artifact in artifacts:
if not artifact.is_dir and artifact.path == artifact_path:
return artifact
return None