# Source code for composer.utils.object_store.libcloud_object_store

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility for uploading to and downloading from cloud object stores."""
import io
import os
import pathlib
import uuid
from typing import Any, Callable, Dict, List, Optional, Union

from requests.exceptions import ConnectionError
from urllib3.exceptions import ProtocolError

from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError

__all__ = ['LibcloudObjectStore']

class LibcloudObjectStore(ObjectStore):
    """Utility for uploading to and downloading from object (blob) stores, such as Amazon S3.

    .. rubric:: Example

    Here's an example for an Amazon S3 bucket named ``MY_CONTAINER``:

    >>> from composer.utils import LibcloudObjectStore
    >>> object_store = LibcloudObjectStore(
    ...     provider="s3",
    ...     container="MY_CONTAINER",
    ...     provider_kwargs={
    ...         "key": "AKIA...",
    ...         "secret": "*********",
    ...     }
    ... )
    >>> object_store
    <composer.utils.object_store.libcloud_object_store.LibcloudObjectStore object at ...>

    Args:
        provider (str): Cloud provider to use. The value is passed to libcloud's
            ``get_driver`` to look up the storage driver (for example, ``"s3"``).

            .. seealso:: :doc:`Full list of libcloud providers <libcloud:storage/supported_providers>`

        container (str): The name of the container (i.e. bucket) to use.
        chunk_size (int, optional): Number of bytes read per chunk when streaming
            uploads and downloads. Default: ``1_024 * 1_024`` (1 MiB).
        key_environ (str, optional): Environment variable name for the API Key. Only used
            if 'key' is not in ``provider_kwargs``. Default: None.
        secret_environ (str, optional): Environment variable name for the Secret password. Only
            used if 'secret' is not in ``provider_kwargs``. Default: None.
        provider_kwargs (Dict[str, Any], optional): Keyword arguments to pass into the constructor
            for the specified provider. These arguments would usually include the cloud region
            and credentials. The dictionary is copied; the caller's object is never modified.

            Common keys are:

            * ``key`` (str): API key or username to be used (required).
            * ``secret`` (str): Secret password to be used (required).
            * ``secure`` (bool): Whether to use HTTPS or HTTP. Note: Some providers only support
              HTTPS, and it is on by default.
            * ``host`` (str): Override hostname used for connections.
            * ``port`` (int): Override port used for connections.
            * ``api_version`` (str): Optional API version. Only used by drivers which support
              multiple API versions.
            * ``region`` (str): Optional driver region. Only used by drivers which support
              multiple regions.

            .. seealso:: The constructor of the selected libcloud storage driver.
    """

    def __init__(self,
                 provider: str,
                 container: str,
                 chunk_size: int = 1_024 * 1_024,
                 key_environ: Optional[str] = None,
                 secret_environ: Optional[str] = None,
                 provider_kwargs: Optional[Dict[str, Any]] = None) -> None:
        # libcloud is an optional dependency, so import it lazily and surface a helpful error.
        try:
            from libcloud.storage.providers import get_driver
        except ImportError as e:
            raise MissingConditionalImportError('libcloud', 'apache-libcloud') from e
        provider_cls = get_driver(provider)

        # Work on a copy so that injecting credentials never mutates the caller's dictionary.
        provider_kwargs = dict(provider_kwargs) if provider_kwargs is not None else {}

        # Explicit kwargs win; the environment is only consulted as a fallback.
        if 'key' not in provider_kwargs and key_environ and key_environ in os.environ:
            provider_kwargs['key'] = os.environ[key_environ]

        if 'secret' not in provider_kwargs and secret_environ and secret_environ in os.environ:
            provider_kwargs['secret'] = os.environ[secret_environ]

        self.chunk_size = chunk_size
        self._provider_name = provider
        self._provider = provider_cls(**provider_kwargs)
        self._container = self._provider.get_container(container)

    def get_uri(self, object_name: str):
        """Return ``<provider>://<container name>/<object_name>`` for the given object."""
        return f'{self._provider_name}://{self._container.name}/{object_name}'

    def upload_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Stream the local file ``filename`` to the container as ``object_name``.

        Args:
            object_name (str): Name of the destination object.
            filename (str | pathlib.Path): Path of the local file to upload.
            callback (Callable[[int, int], None], optional): Forwarded to
                ``iterate_with_callback`` together with the file size in bytes.

        Raises:
            ObjectStoreTransientError: If the upload failed with an error classified as transient.
        """
        with open(filename, 'rb') as f:
            stream = iterate_with_callback(
                _file_to_iterator(f, self.chunk_size),
                os.fstat(f.fileno()).st_size,
                callback,
            )
            # The stream reads lazily from ``f``, so the upload must happen while the file is open.
            try:
                self._provider.upload_object_via_stream(
                    stream,
                    container=self._container,
                    object_name=object_name,
                )
            except Exception as e:
                self._ensure_transient_errors_are_wrapped(e)

    def _ensure_transient_errors_are_wrapped(self, exc: Exception):
        """Re-raise ``exc``, wrapping it in :class:`ObjectStoreTransientError` if it looks transient.

        This method always raises; it never returns normally.
        """
        from libcloud.common.types import LibcloudError
        if isinstance(exc, (LibcloudError, ProtocolError, TimeoutError, ConnectionError)):
            if isinstance(exc, LibcloudError):
                # The S3 driver does not encode the error code in an easy-to-parse manner
                # So first checking if the error code is non-transient
                is_transient_error = any(x in str(exc) for x in ('408', '409', '425', '429', '500', '503', '504'))
                if not is_transient_error:
                    raise exc
            raise ObjectStoreTransientError() from exc
        raise exc

    def _get_object(self, object_name: str):
        """Get object from object store.

        Args:
            object_name (str): The name of the object.

        Raises:
            FileNotFoundError: If the provider reports that the object does not exist.
            ObjectStoreTransientError: If the lookup failed with an error classified as transient.
        """
        from libcloud.storage.types import ObjectDoesNotExistError
        try:
            return self._provider.get_object(self._container.name, object_name)
        except ObjectDoesNotExistError as e:
            raise FileNotFoundError(f'Object not found: {self.get_uri(object_name)}') from e
        except Exception as e:
            self._ensure_transient_errors_are_wrapped(e)

    def get_object_size(self, object_name: str) -> int:
        """Return the ``size`` attribute of the object named ``object_name``."""
        obj = self._get_object(object_name)
        # ``_get_object`` either returns an object or raises; the assert narrows the type.
        assert obj is not None
        return obj.size

    def download_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        overwrite: bool = False,
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Download ``object_name`` to the local path ``filename``.

        Args:
            object_name (str): Name of the object to download.
            filename (str | pathlib.Path): Destination path. Parent directories are created.
            overwrite (bool, optional): Whether an existing ``filename`` may be replaced.
                Default: ``False``.
            callback (Callable[[int, int], None], optional): Forwarded to
                ``iterate_with_callback`` together with the object size.

        Raises:
            FileExistsError: If ``filename`` exists and ``overwrite`` is ``False``.
            FileNotFoundError: If the object does not exist.
            ObjectStoreTransientError: If the download failed with an error classified as transient.
        """
        if os.path.exists(filename) and not overwrite:
            # If the file already exists, short-circuit and skip the download
            raise FileExistsError(f'filename {filename} exists and overwrite was set to False.')

        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        obj = self._get_object(object_name)
        # Download first to a tempfile, and then rename, in case the file gets corrupted in transit
        tmp_filepath = str(filename) + f'.{uuid.uuid4()}.tmp'
        try:
            with open(tmp_filepath, 'wb+') as f:
                assert obj is not None
                stream = self._provider.download_object_as_stream(obj, chunk_size=self.chunk_size)
                for chunk in iterate_with_callback(stream, obj.size, callback):
                    f.write(chunk)
        except Exception as e:
            # The download failed for some reason. Make a best-effort attempt to remove the temporary file.
            try:
                os.remove(tmp_filepath)
            except OSError:
                pass
            # Always raises, so the rename below only runs after a successful download.
            self._ensure_transient_errors_are_wrapped(e)

        # The download was successful.
        if overwrite:
            os.replace(tmp_filepath, filename)
        else:
            os.rename(tmp_filepath, filename)

    def list_objects(self, prefix: Optional[str] = None) -> List[str]:
        """Return the names of the objects in the container whose names start with ``prefix``.

        Args:
            prefix (str, optional): Name prefix to filter on. ``None`` is treated as ``''``.
        """
        if prefix is None:
            prefix = ''
        return [obj.name for obj in self._provider.list_container_objects(self._container, prefix=prefix)]
def _file_to_iterator(f: io.IOBase, chunk_size: int):
    """Yield the remaining contents of the binary file ``f`` in chunks.

    Args:
        f (io.IOBase): A file object opened in binary mode; it is read from its current position.
        chunk_size (int): Maximum number of bytes requested per read.

    Yields:
        bytes: Successive non-empty chunks, until a read returns ``b''`` (end of file).
    """
    while True:
        byte = f.read(chunk_size)
        if byte == b'':
            # An empty read signals end of file.
            break
        yield byte