# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Abstract class for utilities that upload to and download from object stores."""
import abc
import pathlib
from types import TracebackType
from typing import Callable, Optional, Type, Union
__all__ = [
    'ObjectStore',
    'ObjectStoreTransientError',
]
class ObjectStoreTransientError(RuntimeError):
    """Custom exception class to signify transient errors.

    Implementations of the :class:`.ObjectStore` should re-raise any transient exceptions
    (e.g. too many requests, temporarily unavailable) with this class, so callers can easily
    detect whether they should attempt to retry any operation.

    For example, the :class:`.S3ObjectStore` does the following:

    .. testcode::

        from composer.utils import ObjectStore, ObjectStoreTransientError
        import botocore.exceptions

        class S3ObjectStore(ObjectStore):

            def upload_object(self, object_name: str, filename: str):
                try:
                    ...
                except botocore.exceptions.ClientError as e:
                    if e.response['Error']['Code'] == 'LimitExceededException':
                        raise ObjectStoreTransientError(e.response['Error']['Code']) from e
                    # Non-transient errors propagate unchanged.
                    raise

    Then, callers can automatically handle exceptions:

    .. testcode::

        import time
        from composer.utils import ObjectStore, ObjectStoreTransientError

        def upload_file(object_store: ObjectStore, max_num_attempts: int = 3):
            for i in range(max_num_attempts):
                try:
                    object_store.upload_object(...)
                except ObjectStoreTransientError:
                    if i + 1 == max_num_attempts:
                        raise
                    else:
                        # Try again after exponential back-off
                        time.sleep(2**i)
                else:
                    # upload successful
                    return
    """
class ObjectStore(abc.ABC):
    """Abstract class for implementing object stores, such as LibcloudObjectStore and S3ObjectStore.

    None of the methods below are declared with ``@abc.abstractmethod``. Instead, each default
    implementation raises :exc:`NotImplementedError`, so a subclass only needs to override
    the operations it supports.

    Instances can be used as context managers. Leaving the ``with`` block calls :meth:`close`.
    """

    def get_uri(self, object_name: str) -> str:
        """Returns the URI for ``object_name``.

        .. note::

            This function does not check that ``object_name`` is in the object store.
            It computes the URI statically.

        Args:
            object_name (str): The object name.

        Returns:
            str: The URI for ``object_name`` in the object store.
        """
        del object_name  # unused
        raise NotImplementedError(f'{type(self).__name__}.get_uri is not implemented')

    def upload_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        callback: Optional[Callable[[int, int], None]] = None,
        **kwargs,
    ) -> None:
        """Upload an object currently located on a disk.

        Args:
            object_name (str): Object name (where object will be stored in the container)
            filename (str | pathlib.Path): Path to the object on disk
            callback ((int, int) -> None, optional): If specified, the callback is periodically called with the number of bytes
                uploaded and the total size of the object being uploaded.
            **kwargs: other arguments to the upload object function are supported
                and will be passed in to the underlying object store upload call.
                Currently only used for S3ObjectStore.

        Raises:
            ObjectStoreTransientError: If there was a transient connection issue with uploading the object.
        """
        del object_name, filename, callback, kwargs  # unused
        raise NotImplementedError(f'{type(self).__name__}.upload_object is not implemented')

    def get_object_size(self, object_name: str) -> int:
        """Get the size of an object, in bytes.

        Args:
            object_name (str): The name of the object.

        Returns:
            int: The object size, in bytes.

        Raises:
            FileNotFoundError: If the file was not found in the object store.
            ObjectStoreTransientError: If there was a transient connection issue with getting the object size.
        """
        del object_name  # unused
        raise NotImplementedError(f'{type(self).__name__}.get_object_size is not implemented')

    def download_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        overwrite: bool = False,
        callback: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Download an object to the specified destination path.

        Args:
            object_name (str): The name of the object to download.
            filename (str | pathlib.Path): Full path to a file or a directory where the incoming file will be saved.
            overwrite (bool, optional): Whether to overwrite an existing file at ``filename``, if it exists.
                (default: ``False``)
            callback ((int, int) -> None, optional): If specified, the callback is periodically called with the number of bytes already
                downloaded and the total size of the object.

        Raises:
            FileExistsError: If ``filename`` already exists and ``overwrite`` is ``False``.
            FileNotFoundError: If the file was not found in the object store.
            ObjectStoreTransientError: If there was a transient connection issue with downloading the object.
        """
        del object_name, filename, overwrite, callback  # unused
        raise NotImplementedError(f'{type(self).__name__}.download_object is not implemented')

    def list_objects(self, prefix: Optional[str]) -> list[str]:
        """List all objects in the object store with the given prefix.

        Args:
            prefix (str | None): The prefix to search for.

        Returns:
            list[str]: A list of object names that match the prefix.
        """
        del prefix  # unused
        raise NotImplementedError(f'{type(self).__name__}.list_objects is not implemented')

    def close(self):
        """Close the object store.

        The default implementation does nothing. Subclasses that hold resources should override it.
        """
        pass

    def __enter__(self):
        """Enter a ``with`` block, yielding this object store itself."""
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Call :meth:`close` when leaving a ``with`` block.

        Returns ``None``, so any exception raised inside the block is not suppressed.
        """
        del exc_type, exc, traceback  # unused
        self.close()