Source code for composer.utils.object_store.uc_object_store

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Databricks Unity Catalog Volumes object store."""

from __future__ import annotations

import json
import logging
import os
import pathlib
import uuid
from typing import Callable, List, Optional

from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Public API of this module.
__all__ = ['UCObjectStore']

# Value of ``DatabricksError.error_code`` that ``_wrap_errors`` treats as "object does not exist".
_NOT_FOUND_ERROR_CODE = 'NOT_FOUND'


def _wrap_errors(uri: str, e: Exception):
    """Re-raise ``e`` as one of the exception types used by the object store layer.

    This function never returns normally; it always raises.

    Args:
        uri (str): URI of the object being accessed, used in the error message.
        e (Exception): The exception to translate.

    Raises:
        FileNotFoundError: If ``e`` is a ``DatabricksError`` that is either a ``NotFound``
            instance or carries the ``NOT_FOUND`` error code.
        ObjectStoreTransientError: For every other exception.
    """
    from databricks.sdk.core import DatabricksError
    from databricks.sdk.errors.mapping import NotFound

    # Anything that is not a Databricks SDK error is reported as transient.
    if not isinstance(e, DatabricksError):
        raise ObjectStoreTransientError from e

    object_missing = isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE  # type: ignore
    if object_missing:
        raise FileNotFoundError(f'Object {uri} not found') from e

    raise ObjectStoreTransientError from e


class UCObjectStore(ObjectStore):
    """Utility class for uploading and downloading data from Databricks Unity Catalog (UC) Volumes.

    .. note::

        Using this object store requires setting `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
        environment variables with the right credentials to be able to access the files in
        the unity catalog volumes.

    Args:
        path (str): The Databricks UC Volume path that is of the format
            `Volumes/<catalog-name>/<schema-name>/<volume-name>/path/to/folder`.
            The path must begin with the ``Volumes`` directory (without a leading slash)
            and adhere to the above format, since this object store only supports
            Unity Catalog Volumes and not other Databricks Filesystems.
    """

    _UC_VOLUME_LIST_API_ENDPOINT = '/api/2.0/fs/list'
    _UC_VOLUME_FILES_API_ENDPOINT = '/api/2.0/fs/files'

    def __init__(self, path: str) -> None:
        try:
            from databricks.sdk import WorkspaceClient
        except ImportError as e:
            raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e

        # Create the client exactly once, inside the guard, so that any credential
        # problem surfaces as the descriptive ValueError below.
        try:
            self.client = WorkspaceClient()
        except Exception as e:
            raise ValueError(
                'Databricks SDK credentials not correctly setup. '
                'Visit https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication '
                'to identify different ways to setup credentials.') from e
        self.prefix = self.validate_path(path)

    @staticmethod
    def validate_path(path: str) -> str:
        """Parses the given path to extract the UC Volume prefix from the path.

        .. note::

            This function only uses the first 4 directories from the path to construct the
            UC Volumes prefix and will ignore the rest of the directories in the path

        Args:
            path (str): The Databricks UC Volume path of the format
                `Volumes/<catalog-name>/<schema-name>/<volume-name>/path/to/folder`.

        Returns:
            str: The normalized `Volumes/<catalog-name>/<schema-name>/<volume-name>` prefix.

        Raises:
            ValueError: If the first directory of the path is not ``Volumes`` or the path has
                fewer than 4 directories.
        """
        path = os.path.normpath(path)
        dirs = path.split(os.sep)
        # Compare the whole first component: a plain ``startswith('Volumes')`` on the
        # string would also accept directories such as ``Volumesfoo``.
        if dirs[0] != 'Volumes':
            raise ValueError('Databricks Unity Catalog Volumes paths should start with "Volumes".')

        if len(dirs) < 4:
            raise ValueError(f'Databricks Unity Catalog Volumes path expected to be of the format '
                             '`Volumes/<catalog-name>/<schema-name>/<volume-name>/<optional-path>`. '
                             f'Found path={path}')

        # The first 4 dirs form the prefix
        return os.path.join(*dirs[:4])

    def _get_object_path(self, object_name: str) -> str:
        """Return the absolute Single Path Namespace for the given object_name.

        Args:
            object_name (str): Absolute or relative path of the object w.r.t. the
                UC Volumes root.

        Returns:
            str: The path of the object, rooted at ``/<prefix>``.
        """
        prefix = self.prefix
        # Convert object name to a relative path if the prefix is included. The match is
        # made on a path-separator boundary; a character-wise common prefix would treat a
        # sibling such as ``<prefix>2/file`` as living inside the prefix.
        if object_name == prefix or object_name.startswith(prefix + os.sep):
            object_name = os.path.relpath(object_name, start=prefix)
        return os.path.join('/', prefix, object_name)

    def get_uri(self, object_name: str) -> str:
        """Returns the URI for ``object_name``.

        .. note::

            This function does not check that ``object_name`` is in the object store.
            It computes the URI statically.

        Args:
            object_name (str): The object name.

        Returns:
            str: The URI for ``object_name`` in the object store.
        """
        return f'dbfs:{self._get_object_path(object_name)}'

    def upload_object(self,
                      object_name: str,
                      filename: str | pathlib.Path,
                      callback: Callable[[int, int], None] | None = None) -> None:
        """Upload a file from local to UC volumes.

        Args:
            object_name (str): Name of the stored object in UC volumes w.r.t. volume root.
            filename (str | pathlib.Path): Path to the object on disk
            callback ((int, int) -> None, optional): Unused
        """
        # remove unused variable
        del callback

        with open(filename, 'rb') as f:
            self.client.files.upload(self._get_object_path(object_name), f)

    def download_object(self,
                        object_name: str,
                        filename: str | pathlib.Path,
                        overwrite: bool = False,
                        callback: Callable[[int, int], None] | None = None) -> None:
        """Download the given object from UC Volumes to the specified filename.

        The data is first written to a temporary file next to ``filename`` and moved into
        place only after the download completes.

        Args:
            object_name (str): The name of the object to download i.e. path relative to the root of the volume.
            filename (str | pathlib.Path): The local path where the file needs to be downloaded.
            overwrite (bool, optional): Whether to overwrite an existing file at ``filename``, if it exists.
                (default: ``False``)
            callback ((int, int) -> None, optional): Unused

        Raises:
            FileExistsError: If ``filename`` already exists and ``overwrite`` is ``False``.
            FileNotFoundError: If the file was not found in UC volumes.
            ObjectStoreTransientError: If there was any other error querying the Databricks UC volumes that should be retried.
        """
        # remove unused variable
        del callback

        if os.path.exists(filename) and not overwrite:
            raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')

        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        tmp_path = str(filename) + f'{uuid.uuid4()}.tmp'

        try:
            from databricks.sdk.core import DatabricksError
            try:
                contents = self.client.files.download(self._get_object_path(object_name)).contents
                assert contents is not None
                with contents as resp:  # pyright: ignore
                    with open(tmp_path, 'wb') as f:
                        # Chunk the data into multiple blocks of 64MB to avoid
                        # OOMs when downloading really large files
                        for chunk in iter(lambda: resp.read(64 * 1024 * 1024), b''):
                            f.write(chunk)
            except DatabricksError as e:
                _wrap_errors(self.get_uri(object_name), e)
        except BaseException:
            # Catch everything (including KeyboardInterrupt) so that a partially written
            # temporary file is removed on a best-effort basis; the error is re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        else:
            if overwrite:
                os.replace(tmp_path, filename)
            else:
                os.rename(tmp_path, filename)

    def get_object_size(self, object_name: str) -> int:
        """Check that the object exists in UC volumes and return a placeholder size.

        .. note::

            The value returned on success is a fixed dummy value (``1000000``), not the
            real size of the object; this method is only usable as an existence check.

        Args:
            object_name (str): The name of the object.

        Returns:
            int: A dummy object size, in bytes.

        Raises:
            FileNotFoundError: If the file was not found in the object store.
            ObjectStoreTransientError: If there was any other error querying Databricks.
        """
        from databricks.sdk.core import DatabricksError
        try:
            # Note: The UC team is working on changes to fix the files.get_status API, but it currently
            # does not work. Once fixed, we will call the files API endpoint. We currently only use this
            # function in Composer and LLM-foundry to check the UC object's existence.
            object_path = self._get_object_path(object_name).lstrip('/')
            # This is a URL path, so it is assembled with an explicit '/' rather than os.path.join.
            path = f'{self._UC_VOLUME_FILES_API_ENDPOINT}/{object_path}'
            self.client.api_client.do(method='HEAD', path=path, headers={'Source': 'mosaicml/composer'})
            return 1000000  # Dummy value, as we don't have a way to get the size of the file
        except DatabricksError as e:
            # If the code reaches here, the request failed; _wrap_errors always raises.
            _wrap_errors(self.get_uri(object_name), e)

        # Not reached at runtime; keeps the declared ``int`` return type satisfied.
        return -1

    def list_objects(self, prefix: Optional[str] = None) -> List[str]:
        """List all objects in the object store with the given prefix.

        Directories are traversed iteratively; only entries that are not directories are returned.

        Args:
            prefix (str, optional): The prefix to search for. If empty or ``None``, the
                UC Volume prefix of this object store is used.

        Returns:
            list[str]: The ``path`` of every file found under the prefix, as reported by the list API.

        Raises:
            FileNotFoundError: If the path was not found in UC volumes.
            ObjectStoreTransientError: If there was any other error querying Databricks.
        """
        if not prefix:
            prefix = self.prefix

        from databricks.sdk.core import DatabricksError
        try:
            # NOTE: This API is in preview and should not be directly used outside of this instance
            # Emitted through the module logger (``logging.warn`` is deprecated and targets the root logger).
            log.warning('UCObjectStore.list_objects is experimental.')

            # Iteratively get all UC Volume files with `prefix`.
            stack = [prefix]
            all_files = []

            while len(stack) > 0:
                current_path = stack.pop()

                # Note: Databricks SDK handles HTTP errors and retries.
                # See https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/core.py#L125 and
                # https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/retries.py#L33 .
                resp = self.client.api_client.do(method='GET',
                                                 path=self._UC_VOLUME_LIST_API_ENDPOINT,
                                                 data=json.dumps({'path': self._get_object_path(current_path)}),
                                                 headers={'Source': 'mosaicml/composer'})

                assert isinstance(resp, dict), 'Response is not a dictionary'

                for f in resp.get('files', []):
                    fpath = f['path']
                    if f['is_dir']:
                        stack.append(fpath)
                    else:
                        all_files.append(fpath)

            return all_files

        except DatabricksError as e:
            # _wrap_errors always raises.
            _wrap_errors(self.get_uri(prefix), e)

        # Not reached at runtime; keeps the declared ``List[str]`` return type satisfied.
        return []