# Source code for composer.utils.object_store

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility for uploading to and downloading from cloud object stores."""
import dataclasses
import os
import sys
import tempfile
import textwrap
import uuid
from typing import Any, Dict, Iterator, Optional, Union

import yahp as hp
from libcloud.storage.providers import get_driver
from libcloud.storage.types import ObjectDoesNotExistError

# Public names of this module; this list controls what ``from ... import *`` exports.
__all__ = ["ObjectStoreHparams", "ObjectStore"]


@dataclasses.dataclass
class ObjectStoreHparams(hp.Hparams):
    """:class:`~composer.utils.object_store.ObjectStore` hyperparameters.

    .. rubric:: Example

    Here's an example on how to connect to an Amazon S3 bucket. This example assumes:

    * The container is named ``MY_CONTAINER``.
    * The AWS Access Key ID is stored in an environment variable named ``AWS_ACCESS_KEY_ID``.
    * The Secret Access Key is in an environment variable named ``AWS_SECRET_ACCESS_KEY``.

    .. testsetup:: composer.utils.object_store.ObjectStoreHparams.__init__.s3

        import os

        os.environ["AWS_ACCESS_KEY_ID"] = "key"
        os.environ["AWS_SECRET_ACCESS_KEY"] = "secret"

    .. doctest:: composer.utils.object_store.ObjectStoreHparams.__init__.s3

        >>> from composer.utils import ObjectStoreHparams
        >>> provider_hparams = ObjectStoreHparams(
        ...     provider="s3",
        ...     container="MY_CONTAINER",
        ...     key_environ="AWS_ACCESS_KEY_ID",
        ...     secret_environ="AWS_SECRET_ACCESS_KEY",
        ... )
        >>> provider = provider_hparams.initialize_object()
        >>> provider
        <composer.utils.object_store.ObjectStore object at ...>

    Args:
        provider (str): Cloud provider to use.

            See :class:`ObjectStore` for documentation.
        container (str): The name of the container (i.e. bucket) to use.
        key_environ (str, optional): The name of an environment variable containing the API key or username
            to use to connect to the provider. If no key is required, then set this field to ``None``.
            (default: ``None``)

            For security reasons, composer requires that the key be specified via an environment variable.
            For example, if your key is an environment variable called ``OBJECT_STORE_KEY`` that is set to
            ``MY_KEY``, then you should set this parameter equal to ``OBJECT_STORE_KEY``. Composer will read
            the key like this:

            .. testsetup:: composer.utils.object_store.ObjectStoreHparams.__init__.key

                import os
                import functools
                from composer.utils import ObjectStoreHparams

                os.environ["OBJECT_STORE_KEY"] = "MY_KEY"
                ObjectStoreHparams = functools.partial(ObjectStoreHparams, provider="s3", container="container")

            .. doctest:: composer.utils.object_store.ObjectStoreHparams.__init__.key

                >>> import os
                >>> params = ObjectStoreHparams(key_environ="OBJECT_STORE_KEY")
                >>> key = os.environ[params.key_environ]
                >>> key
                'MY_KEY'

        secret_environ (str, optional): The name of an environment variable containing the API secret or
            password to use for the provider. If no secret is required, then set this field to ``None``.
            (default: ``None``)

            For security reasons, composer requires that the secret be specified via an environment variable.
            For example, if your secret is an environment variable called ``OBJECT_STORE_SECRET`` that is set
            to ``MY_SECRET``, then you should set this parameter equal to ``OBJECT_STORE_SECRET``. Composer
            will read the secret like this:

            .. testsetup:: composer.utils.object_store.ObjectStoreHparams.__init__.secret

                import os
                import functools
                from composer.utils import ObjectStoreHparams

                original_secret = os.environ.get("OBJECT_STORE_SECRET")
                os.environ["OBJECT_STORE_SECRET"] = "MY_SECRET"
                ObjectStoreHparams = functools.partial(ObjectStoreHparams, provider="s3", container="container")

            .. doctest:: composer.utils.object_store.ObjectStoreHparams.__init__.secret

                >>> import os
                >>> params = ObjectStoreHparams(secret_environ="OBJECT_STORE_SECRET")
                >>> secret = os.environ[params.secret_environ]
                >>> secret
                'MY_SECRET'

        region (str, optional): Cloud region to use for the cloud provider.
            Most providers do not require the region to be specified. (default: ``None``)
        host (str, optional): Override the hostname for the cloud provider. (default: ``None``)
        port (int, optional): Override the port for the cloud provider. (default: ``None``)
        extra_init_kwargs (Dict[str, Any], optional): Extra keyword arguments to pass into the constructor
            for the specified provider. (default: ``None``, which is equivalent to an empty dictionary)

            .. seealso:: :class:`libcloud.storage.base.StorageDriver`
    """

    provider: str = hp.required("Cloud provider to use.")
    container: str = hp.required("The name of the container (i.e. bucket) to use.")
    key_environ: Optional[str] = hp.optional(
        textwrap.dedent("""\
            The name of an environment variable containing an API key or username to use to connect to the provider."""),
        default=None,
    )
    secret_environ: Optional[str] = hp.optional(
        textwrap.dedent("""\
            The name of an environment variable containing an API secret or password to use to connect to the provider."""),
        default=None,
    )
    region: Optional[str] = hp.optional("Cloud region to use", default=None)
    host: Optional[str] = hp.optional("Override hostname for connections", default=None)
    port: Optional[int] = hp.optional("Override port for connections", default=None)
    extra_init_kwargs: Dict[str, Any] = hp.optional(
        "Extra keyword arguments to pass into the constructor for the specified provider.",
        default_factory=dict,
    )

    def get_provider_kwargs(self) -> Dict[str, Any]:
        """Build the ``provider_kwargs`` argument used to construct a :class:`.ObjectStore`.

        ``host``, ``port`` and ``region`` are included only when they are set. ``key`` and ``secret``
        are always included: they are read from the environment variables named by ``key_environ`` and
        ``secret_environ``, or are ``None`` when the corresponding field is ``None``. Entries from
        ``extra_init_kwargs`` are applied last, so they take precedence over everything else.

        Returns:
            Dict[str, Any]: The ``provider_kwargs`` for use in constructing an :class:`.ObjectStore`.
        """
        connection_fields = ((name, getattr(self, name)) for name in ("host", "port", "region"))
        provider_kwargs: Dict[str, Any] = {name: value for name, value in connection_fields if value is not None}

        # Credentials are never stored on the hparams themselves; only the names of the
        # environment variables that hold them are.
        if self.key_environ is None:
            provider_kwargs["key"] = None
        else:
            provider_kwargs["key"] = os.environ[self.key_environ]

        if self.secret_environ is None:
            provider_kwargs["secret"] = None
        else:
            provider_kwargs["secret"] = os.environ[self.secret_environ]

        provider_kwargs.update(self.extra_init_kwargs)
        return provider_kwargs

    def initialize_object(self):
        """Construct the :class:`.ObjectStore` described by these hyperparameters.

        Returns:
            ObjectStore: The object_store.
        """
        provider_kwargs = self.get_provider_kwargs()
        return ObjectStore(provider=self.provider, container=self.container, provider_kwargs=provider_kwargs)
class ObjectStore:
    """Utility for uploading to and downloading from object (blob) stores, such as Amazon S3.

    .. rubric:: Example

    Here's an example for an Amazon S3 bucket named ``MY_CONTAINER``:

    >>> from composer.utils import ObjectStore
    >>> object_store = ObjectStore(
    ...     provider="s3",
    ...     container="MY_CONTAINER",
    ...     provider_kwargs={
    ...         "key": "AKIA...",
    ...         "secret": "*********",
    ...     }
    ... )
    >>> object_store
    <composer.utils.object_store.ObjectStore object at ...>

    Args:
        provider (str): Cloud provider to use. Valid options are:

            * :mod:`~libcloud.storage.drivers.atmos`
            * :mod:`~libcloud.storage.drivers.auroraobjects`
            * :mod:`~libcloud.storage.drivers.azure_blobs`
            * :mod:`~libcloud.storage.drivers.backblaze_b2`
            * :mod:`~libcloud.storage.drivers.cloudfiles`
            * :mod:`~libcloud.storage.drivers.digitalocean_spaces`
            * :mod:`~libcloud.storage.drivers.google_storage`
            * :mod:`~libcloud.storage.drivers.ktucloud`
            * :mod:`~libcloud.storage.drivers.local`
            * :mod:`~libcloud.storage.drivers.minio`
            * :mod:`~libcloud.storage.drivers.nimbus`
            * :mod:`~libcloud.storage.drivers.ninefold`
            * :mod:`~libcloud.storage.drivers.oss`
            * :mod:`~libcloud.storage.drivers.rgw`
            * :mod:`~libcloud.storage.drivers.s3`

            .. seealso:: :doc:`Full list of libcloud providers <libcloud:storage/supported_providers>`

        container (str): The name of the container (i.e. bucket) to use.
        provider_kwargs (Dict[str, Any], optional): Keyword arguments to pass into the constructor
            for the specified provider. These arguments would usually include the cloud region
            and credentials.

            Common keys are:

            * ``key`` (str): API key or username to be used (required).
            * ``secret`` (str): Secret password to be used (required).
            * ``secure`` (bool): Whether to use HTTPS or HTTP. Note: Some providers only support HTTPS,
              and it is on by default.
            * ``host`` (str): Override hostname used for connections.
            * ``port`` (int): Override port used for connections.
            * ``api_version`` (str): Optional API version. Only used by drivers which support multiple
              API versions.
            * ``region`` (str): Optional driver region. Only used by drivers which support multiple regions.

            .. seealso:: :class:`libcloud.storage.base.StorageDriver`
    """

    def __init__(self, provider: str, container: str, provider_kwargs: Optional[Dict[str, Any]] = None) -> None:
        provider_cls = get_driver(provider)
        if provider_kwargs is None:
            provider_kwargs = {}
        self._provider = provider_cls(**provider_kwargs)
        self._container = self._provider.get_container(container)

    @property
    def provider_name(self):
        """The name of the cloud provider."""
        return self._provider.name

    @property
    def container_name(self):
        """The name of the object storage container."""
        return self._container.name

    def upload_object(self,
                      file_path: str,
                      object_name: str,
                      verify_hash: bool = True,
                      extra: Optional[Dict] = None,
                      headers: Optional[Dict[str, str]] = None):
        """Upload an object currently located on a disk.

        .. seealso:: :meth:`libcloud.storage.base.StorageDriver.upload_object`.

        Args:
            file_path (str): Path to the object on disk.
            object_name (str): Object name (i.e. where the object will be stored in the container.)
            verify_hash (bool, optional): Whether to verify hashes (default: ``True``)
            extra (Optional[Dict], optional): Extra attributes to pass to the underlying provider driver.
                (default: ``None``, which is equivalent to an empty dictionary)
            headers (Optional[Dict[str, str]], optional): Additional request headers, such as CORS headers.
                (defaults: ``None``, which is equivalent to an empty dictionary)
        """
        self._provider.upload_object(file_path=file_path,
                                     container=self._container,
                                     object_name=object_name,
                                     extra=extra,
                                     verify_hash=verify_hash,
                                     headers=headers)

    def upload_object_via_stream(self,
                                 obj: Union[bytes, Iterator[bytes]],
                                 object_name: str,
                                 extra: Optional[Dict] = None,
                                 headers: Optional[Dict[str, str]] = None):
        """Upload an object from memory or from a stream of byte chunks.

        .. seealso:: :meth:`libcloud.storage.base.StorageDriver.upload_object_via_stream`.

        Args:
            obj (bytes | Iterator[bytes]): The object. A ``bytes`` value is handed to the provider
                as a single chunk; an iterator is passed through unchanged.
            object_name (str): Object name (i.e. where the object will be stored in the container.)
            extra (Optional[Dict], optional): Extra attributes to pass to the underlying provider driver.
                (default: ``None``)
            headers (Optional[Dict[str, str]], optional): Additional request headers, such as CORS headers.
                (defaults: ``None``)
        """
        if isinstance(obj, bytes):
            # The payload is already fully in memory, so stream it as one chunk rather than
            # paying for a Python-level iteration (and a new bytes object) per byte.
            obj = iter((obj,))
        self._provider.upload_object_via_stream(iterator=obj,
                                                container=self._container,
                                                object_name=object_name,
                                                extra=extra,
                                                headers=headers)

    def _read_symlink_target(self, symlink_obj) -> str:
        """Download a ``.symlink`` object and return the object name stored inside it."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # A random file name avoids depending on the (possibly nested) object name.
            tmppath = os.path.join(tmpdir, str(uuid.uuid4()))
            self._provider.download_object(obj=symlink_obj,
                                           destination_path=tmppath,
                                           overwrite_existing=True,
                                           delete_on_failure=True)
            with open(tmppath) as f:
                # Drop trailing line terminators so a symlink file that ends with a newline
                # still resolves to the intended object name.
                return f.read().rstrip("\r\n")

    def _get_object(self, object_name: str):
        """Get object from object store, following any symlinks.

        If an object does not exist, automatically checks if it is a symlink by appending ``.symlink``.
        Whenever the object that is found has a name ending in ``.symlink``, its content is read as the
        name of the next object to look up, and the lookup repeats.

        Args:
            object_name (str): The name of the object.

        Returns:
            The provider's object for the fully resolved name.

        Raises:
            RecursionError: If the symlinks form a cycle.

        Any error raised by the provider when neither ``object_name`` nor ``object_name + ".symlink"``
        can be fetched is propagated unchanged.
        """
        visited = set()
        while True:
            if object_name in visited:
                # Fail fast rather than hammering the provider until the interpreter gives up.
                raise RecursionError(f"Symlink cycle detected while resolving object '{object_name}'")
            visited.add(object_name)

            try:
                obj = self._provider.get_object(self._container.name, object_name)
            except ObjectDoesNotExistError:
                # Object not found, check for potential symlink
                obj = self._provider.get_object(self._container.name, object_name + ".symlink")

            if not obj.name.endswith(".symlink"):
                return obj

            object_name = self._read_symlink_target(obj)

    def get_object_size(self, object_name: str) -> int:
        """Get the size of an object, in bytes.

        Args:
            object_name (str): The name of the object.

        Returns:
            int: The object size, in bytes.
        """
        return self._get_object(object_name).size

    def download_object(self,
                        object_name: str,
                        destination_path: str,
                        overwrite_existing: bool = False,
                        delete_on_failure: bool = True):
        """Download an object to the specified destination path.

        .. seealso:: :meth:`libcloud.storage.base.StorageDriver.download_object`.

        Args:
            object_name (str): The name of the object to download.
            destination_path (str): Full path to a file or a directory where the incoming file will be saved.
            overwrite_existing (bool, optional): Set to ``True`` to overwrite an existing file.
                (default: ``False``)
            delete_on_failure (bool, optional): Set to ``True`` to delete a partially downloaded file if
                the download was not successful (hash mismatch / file size). (default: ``True``)
        """
        obj = self._get_object(object_name)
        self._provider.download_object(obj=obj,
                                       destination_path=destination_path,
                                       overwrite_existing=overwrite_existing,
                                       delete_on_failure=delete_on_failure)

    def download_object_as_stream(self, object_name: str, chunk_size: Optional[int] = None):
        """Return an iterator which yields object data.

        .. seealso:: :meth:`libcloud.storage.base.StorageDriver.download_object_as_stream`.

        Args:
            object_name (str): Object name.
            chunk_size (Optional[int], optional): Optional chunk size (in bytes).

        Returns:
            Iterator[bytes]: The object, as a byte stream.
        """
        obj = self._get_object(object_name)
        return self._provider.download_object_as_stream(obj, chunk_size=chunk_size)