Source code for composer.utils.object_store.oci_object_store

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""OCI-Compatible object store."""

from __future__ import annotations

import concurrent.futures
import os
import pathlib
import uuid
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Union

from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.object_store.object_store import ObjectStore

__all__ = ['OCIObjectStore']


def _reraise_oci_errors(uri: str, e: Exception):
    """Translate an exception raised by the OCI SDK into a built-in exception and raise it.

    This function never returns normally: every code path ends in a ``raise``.

    Args:
        uri (str): The URI of the object or bucket being accessed; only used in error messages.
        e (Exception): The exception caught around an OCI SDK call.

    Raises:
        MissingConditionalImportError: If the ``oci`` package cannot be imported.
        FileNotFoundError: If ``e`` is an OCI ``ServiceError`` with status 404 whose code is
            ``'ObjectNotFound'``, or any code other than ``'BucketNotFound'``.
        ValueError: If ``e`` is an OCI ``ServiceError`` with status 404 and code ``'BucketNotFound'``,
            an OCI ``ClientError``, or an OCI ``MultipartUploadError``.
        Exception: ``e`` itself, unchanged, in every other case.
    """
    try:
        import oci
    except ImportError as import_error:
        # Bound to its own name so the ``e`` parameter is never rebound by this clause.
        raise MissingConditionalImportError(
            conda_package='oci',
            extra_deps_group='oci',
            conda_channel='conda-forge',
        ) from import_error

    # Service errors with a 404 status: distinguish missing objects from missing buckets.
    if isinstance(e, oci.exceptions.ServiceError) and e.status == 404:  # type: ignore
        if e.code == 'ObjectNotFound':  # type: ignore
            raise FileNotFoundError(f'Object {uri} not found. {e.message}') from e  # type: ignore
        if e.code == 'BucketNotFound':  # type: ignore
            raise ValueError(f'Bucket specified in {uri} not found. {e.message}') from e  # type: ignore
        # Any other (or missing) error code on a 404 is reported as a missing object.
        raise FileNotFoundError(f'Object {uri} not found with no error code. {e.message}') from e  # type: ignore

    # Client errors
    if isinstance(e, oci.exceptions.ClientError):
        raise ValueError(f'Error with using your OCI config file for uri {uri}') from e
    if isinstance(e, oci.exceptions.MultipartUploadError):
        raise ValueError(f'Error when uploading {uri} using OCI parallelized uploading') from e

    # Otherwise just raise the original error (this includes non-404 service errors).
    raise e


class OCIObjectStore(ObjectStore):
    """Utility for uploading to and downloading from an OCI bucket.

    Args:
        bucket (str): The bucket name.
        prefix (str): A path prefix such as `folder/subfolder/` to prepend to object names. Defaults to ''.
            Within this class the prefix is only applied by :meth:`list_objects`; the other methods use
            ``object_name`` exactly as given.
    """

    def __init__(
        self,
        bucket: str,
        prefix: str = '',
    ) -> None:
        try:
            import oci
        except ImportError as e:
            raise MissingConditionalImportError(
                conda_package='oci',
                extra_deps_group='oci',
                conda_channel='conda-forge',
            ) from e

        # Format paths: no surrounding slashes on the bucket, exactly one trailing slash on a non-empty prefix.
        self.bucket = bucket.strip('/')
        self.prefix = prefix.strip('/')
        if self.prefix != '':
            self.prefix += '/'

        try:
            # The config file named by OCI_CONFIG_FILE wins; otherwise the SDK's default location is used.
            if 'OCI_CONFIG_FILE' in os.environ:
                config = oci.config.from_file(os.environ['OCI_CONFIG_FILE'])
            else:
                config = oci.config.from_file()
            self.client = oci.object_storage.ObjectStorageClient(
                config=config,
                retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY,
            )
        except Exception as e:
            _reraise_oci_errors(self.get_uri(object_name=''), e)

        self.namespace = self.client.get_namespace().data  # pyright: ignore[reportOptionalMemberAccess]
        self.upload_manager = oci.object_storage.UploadManager(self.client)

    def get_uri(self, object_name: str) -> str:
        """Return the ``oci://<bucket>/<object_name>`` URI for ``object_name``."""
        return f'oci://{self.bucket}/{object_name}'

    def get_object_size(self, object_name: str) -> int:
        """Return the size of ``object_name`` as reported by the ``Content-Length`` response header.

        Args:
            object_name (str): The name of the object in the bucket.

        Returns:
            int: The value of the ``Content-Length`` header.

        Raises:
            ValueError: If the response status is not 200.

        Errors raised by the OCI client are translated by ``_reraise_oci_errors``.
        """
        # NOTE(review): this issues ``get_object`` although only a header is read, whereas
        # ``download_object`` uses ``head_object`` for the same purpose -- confirm whether the
        # response body needs to be closed or whether ``head_object`` could be used here.
        try:
            response = self.client.get_object(
                namespace_name=self.namespace,
                bucket_name=self.bucket,
                object_name=object_name,
            )
        except Exception as e:
            _reraise_oci_errors(self.get_uri(object_name), e)

        if response.status == 200:  # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
            data = response.data  # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
            return int(data.headers['Content-Length'])
        else:
            status = response.status  # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
            raise ValueError(f'OCI get_object was not successful with a {status} status code.')

    def upload_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Upload the local file ``filename`` to the bucket as ``object_name``.

        Args:
            object_name (str): The name the object will have in the bucket.
            filename (str | pathlib.Path): Path of the local file to upload.
            callback: Accepted for interface compatibility; it is discarded and never invoked.

        Errors raised by the upload manager are translated by ``_reraise_oci_errors``.
        """
        del callback  # unused
        try:
            self.upload_manager.upload_file(
                namespace_name=self.namespace,
                bucket_name=self.bucket,
                object_name=object_name,
                file_path=filename,
            )
        except Exception as e:
            _reraise_oci_errors(self.get_uri(object_name), e)

    def _download_part(self, object_name, filename, start_byte, end_byte, part_number):
        """Download one byte range of ``object_name`` into a uniquely named temporary file.

        Args:
            object_name: The name of the object in the bucket.
            filename: Despite its name, the *directory* in which the part file is created.
            start_byte: First byte of the range.
            end_byte: Last byte of the range; sent as the upper bound of the ``bytes=start-end`` range.
            part_number: Index of this part, returned unchanged so the caller can restore ordering.

        Returns:
            tuple: ``(part_number, tmp_part_path)``.
        """
        range_header = f'bytes={start_byte}-{end_byte}'
        tmp_part_path = os.path.join(filename, f'part-{part_number}-{uuid.uuid4()}.tmp')
        response = self.client.get_object(
            namespace_name=self.namespace,
            bucket_name=self.bucket,
            object_name=object_name,
            range=range_header,
        )
        with open(tmp_part_path, 'wb') as f:
            f.write(response.data.content)  # pyright: ignore[reportOptionalMemberAccess]
        return part_number, tmp_part_path

    def download_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        overwrite: bool = False,
        callback: Optional[Callable[[int, int], None]] = None,
        min_part_size: int = 128000000,
        num_parts: int = 10,
    ):
        """Download ``object_name`` to the local path ``filename``, fetching byte ranges in parallel.

        The object is split into at most ``num_parts`` ranges (fewer when the object is smaller than
        ``num_parts * min_part_size``). Each range is downloaded to a temporary file by a thread
        pool, the parts are concatenated in order, and the result is moved to ``filename``.

        Args:
            object_name (str): The name of the object in the bucket.
            filename (str | pathlib.Path): Destination path on the local filesystem.
            overwrite (bool): Whether an existing file at ``filename`` may be replaced.
            callback: Accepted for interface compatibility; it is discarded and never invoked.
            min_part_size (int): Lower bound, in bytes, used to decide how many parts to use.
            num_parts (int): Maximum number of parts to download concurrently.

        Raises:
            FileExistsError: If ``filename`` exists and ``overwrite`` is ``False``.
            ValueError: If ``min_part_size`` or ``num_parts`` is not positive.

        Errors raised by the OCI client are translated by ``_reraise_oci_errors``.
        """
        del callback  # unused
        if os.path.exists(filename) and not overwrite:
            raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False')
        if min_part_size < 1 or num_parts < 1:
            # Fail early with a clear message instead of a ZeroDivisionError/IndexError further down.
            raise ValueError(
                f'min_part_size and num_parts must be positive, got min_part_size={min_part_size} '
                f'and num_parts={num_parts}',
            )

        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        # Get the size of the object
        object_size = 0
        try:
            head_object_response = self.client.head_object(self.namespace, self.bucket, object_name)
            object_size = int(
                head_object_response.headers['content-length'],  # pyright: ignore[reportOptionalMemberAccess]
            )
        except Exception as e:
            _reraise_oci_errors(self.get_uri(object_name), e)

        # Calculate the part sizes: spread the remainder over the first parts and drop empty parts
        # (an empty object therefore yields no parts and an empty output file).
        num_parts_from_size = max(object_size // min_part_size, 1)
        num_parts = min(num_parts, num_parts_from_size)
        base_part_size, remainder = divmod(object_size, num_parts)
        part_sizes = [base_part_size] * num_parts
        for i in range(remainder):
            part_sizes[i] += 1
        part_sizes = [part_size for part_size in part_sizes if part_size > 0]

        # Only the final path component may be used for temporary names: joining the full
        # ``filename`` onto ``dirname``/``temp_dir`` would duplicate its directory components
        # (for relative paths) or escape the temporary directory (for absolute paths).
        base_name = os.path.basename(filename)

        # The temporary directory lives next to the destination so the final move stays on one filesystem.
        with TemporaryDirectory(dir=dirname or os.curdir, prefix=base_name) as temp_dir:
            parts = []
            try:
                # Download parts in parallel
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    futures = []
                    start_byte = 0
                    for i, part_size in enumerate(part_sizes):
                        end_byte = start_byte + part_size - 1
                        futures.append(
                            executor.submit(self._download_part, object_name, temp_dir, start_byte, end_byte, i),
                        )
                        start_byte = end_byte + 1

                    for future in concurrent.futures.as_completed(futures):
                        parts.append(future.result())

                # Parts complete in arbitrary order; restore byte order by part number.
                parts = sorted(parts, key=lambda x: x[0])
            except Exception as e:
                _reraise_oci_errors(self.get_uri(object_name), e)

            # Combine parts
            tmp_path = os.path.join(temp_dir, f'{base_name}-{uuid.uuid4()}.tmp')
            with open(tmp_path, 'wb') as outfile:
                for _, part_file_name in parts:
                    with open(part_file_name, 'rb') as infile:
                        outfile.write(infile.read())

            if overwrite:
                os.replace(tmp_path, filename)
            else:
                os.rename(tmp_path, filename)

    def list_objects(self, prefix: Optional[str] = None) -> List[str]:
        """List the names of all objects whose name starts with the store prefix followed by ``prefix``.

        Args:
            prefix (str, optional): Additional prefix appended to the store's own prefix.
                ``None`` is treated as ``''``.

        Returns:
            List[str]: The object names collected across all result pages.

        Errors raised by the OCI client are translated by ``_reraise_oci_errors``.
        """
        if prefix is None:
            prefix = ''
        if self.prefix:
            prefix = f'{self.prefix}{prefix}'

        object_names = []
        next_start_with = None
        try:
            while True:
                response = self.client.list_objects(
                    namespace_name=self.namespace,
                    bucket_name=self.bucket,
                    prefix=prefix,
                    start=next_start_with,
                ).data  # pyright: ignore[reportOptionalMemberAccess]
                object_names.extend(obj.name for obj in response.objects)
                # A falsy ``next_start_with`` marks the last page.
                next_start_with = response.next_start_with
                if not next_start_with:
                    break
        except Exception as e:
            _reraise_oci_errors(self.get_uri(prefix), e)

        return object_names