Source code for composer.loggers.remote_uploader_downloader

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log files to an object store."""

from __future__ import annotations

import logging
import multiprocessing
import os
import pathlib
import queue
import shutil
import tempfile
import threading
import time
import uuid
import warnings
from multiprocessing.context import SpawnProcess
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from urllib.parse import urlparse

import torch

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
                            ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore,
                            dist, format_name_with_dist, get_file, retry)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
    from composer.core import State
    from composer.devices import Device

log = logging.getLogger(__name__)

__all__ = ['RemoteUploaderDownloader']


def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]):
    """Instantiate the object store selected by ``remote_backend_name``.

    Args:
        remote_backend_name (str): Backend identifier. One of ``'s3'``, ``'oci'``, ``'sftp'``,
            ``'libcloud'``, ``'gs'`` or ``'dbfs'``.
        backend_kwargs (Dict[str, Any]): Keyword arguments forwarded unchanged to the selected
            object store class. The ``'dbfs'`` backend reads ``backend_kwargs['path']``.

    Returns:
        The constructed object store instance.

    Raises:
        ValueError: If ``remote_backend_name`` is not one of the supported backends.
    """
    backend_registry = {
        's3': S3ObjectStore,
        'oci': OCIObjectStore,
        'sftp': SFTPObjectStore,
        'libcloud': LibcloudObjectStore,
        'gs': GCSObjectStore,
    }

    # `dbfs` is not in the registry because the class it maps to depends on the path:
    # MLflow-prefixed paths use :class:`.MLFlowObjectStore`, everything else :class:`.UCObjectStore`.
    if remote_backend_name == 'dbfs':
        dbfs_path = backend_kwargs['path']
        if dbfs_path.startswith(MLFLOW_DBFS_PATH_PREFIX):
            return MLFlowObjectStore(**backend_kwargs)
        # Check that the path conforms to the requirements for UC volume paths before constructing
        UCObjectStore.validate_path(dbfs_path)
        return UCObjectStore(**backend_kwargs)

    backend_cls = backend_registry.get(remote_backend_name)
    if backend_cls is None:
        supported_remote_backends = [*backend_registry, 'dbfs']
        raise ValueError(
            f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})'
        )
    return backend_cls(**backend_kwargs)


class RemoteUploaderDownloader(LoggerDestination):
    r"""Logger destination that uploads (downloads) files to (from) a remote backend.

    This logger destination handles calls to :meth:`.Logger.upload_file`
    and uploads files to :class:`.ObjectStore`, such as AWS S3 or Google Cloud Storage. To minimize the training
    loop performance hit, it supports background uploads.

    .. testcode:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__

        from composer.loggers import RemoteUploaderDownloader
        from composer.utils import LibcloudObjectStore

        remote_uploader_downloader = RemoteUploaderDownloader(
            bucket_uri="s3://my-bucket",
        )

        # Construct the trainer using this logger
        trainer = Trainer(
            ...,
            loggers=[remote_uploader_downloader],
        )

    or

    .. testcode:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__

        from composer.loggers import RemoteUploaderDownloader
        from composer.utils import LibcloudObjectStore

        remote_uploader_downloader = RemoteUploaderDownloader(
            bucket_uri="libcloud://my-bucket",
            backend_kwargs={
                'provider': 's3',
                'container': 'my-bucket',
                'provider_kwargs': {
                    'key': 'AKIA...',
                    'secret': '*********',
                    'region': 'ap-northeast-1',
                },
            },
        )

        # Construct the trainer using this logger
        trainer = Trainer(
            ...,
            loggers=[remote_uploader_downloader],
        )

    or

    .. testcode:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__

        from composer.loggers import RemoteUploaderDownloader
        from composer.trainer import Trainer

        remote_uploader_downloader = RemoteUploaderDownloader(
            bucket_uri="libcloud://my-gcs-bucket",
            backend_kwargs={
                "provider": "google_storage",
                "container": "my-gcs-bucket",
                "key_environ": "MY_HMAC_ACCESS_ID", # Name of env variable for HMAC access id.
                "secret_environ": "MY_HMAC_SECRET", # Name of env variable for HMAC secret.
            },
        )

        # Construct the trainer using this logger
        trainer = Trainer(
            ...,
            loggers=[remote_uploader_downloader],
        )

    .. note::

        This logger destination blocks the training loop only to copy each file into the
        ``upload_staging_folder``; the upload itself happens in the background. Here are some
        additional tips for minimizing the performance impact:

        *   Set ``use_procs=True`` (the default) to use background processes, instead of threads, to perform the
            file uploads. Processes are recommended to ensure that the GIL is not blocking the training loop when
            performing CPU operations on uploaded files (e.g. computing and comparing checksums). Network I/O
            always occurs in the background.

        *   Provide a RAM disk path for the ``upload_staging_folder`` parameter. Copying files to stage on RAM
            will be faster than writing to disk. However, there must be sufficient excess RAM, or
            :exc:`MemoryError`\s may be raised.

    Args:
        bucket_uri (str): The remote uri for the bucket to use (e.g. s3://my-bucket).

        backend_kwargs (Dict[str, Any]): The keyword arguments to construct the remote backend indicated by
            ``bucket_uri``.

            As individual :class:`.ObjectStore` instances are not necessarily thread safe, each worker will
            construct its own :class:`.ObjectStore` instance from ``bucket_uri`` and ``backend_kwargs``.

        file_path_format_string (str, optional): A format string used to determine the remote file path (within
            the specified bucket).

            The following format variables are available:

            +------------------------+-------------------------------------------------------+
            | Variable               | Description                                           |
            +========================+=======================================================+
            | ``{remote_file_name}`` | The name of the file being logged.                    |
            +------------------------+-------------------------------------------------------+
            | ``{run_name}``         | The name of the training run. See                     |
            |                        | :attr:`.State.run_name`.                              |
            +------------------------+-------------------------------------------------------+
            | ``{rank}``             | The global rank, as returned by                       |
            |                        | :func:`~composer.utils.dist.get_global_rank`.         |
            +------------------------+-------------------------------------------------------+
            | ``{local_rank}``       | The local rank of the process, as returned by         |
            |                        | :func:`~composer.utils.dist.get_local_rank`.          |
            +------------------------+-------------------------------------------------------+
            | ``{world_size}``       | The world size, as returned by                        |
            |                        | :func:`~composer.utils.dist.get_world_size`.          |
            +------------------------+-------------------------------------------------------+
            | ``{local_world_size}`` | The local world size, as returned by                  |
            |                        | :func:`~composer.utils.dist.get_local_world_size`.    |
            +------------------------+-------------------------------------------------------+
            | ``{node_rank}``        | The node rank, as returned by                         |
            |                        | :func:`~composer.utils.dist.get_node_rank`.           |
            +------------------------+-------------------------------------------------------+

            Leading slashes (``'/'``) will be stripped.

            Consider the following example, which subfolders the remote files by their rank:

            .. testsetup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string

                import os

                os.makedirs('path/to', exist_ok=True)

                with open('path/to/file.txt', 'w+') as f:
                    f.write('hi')

            .. doctest:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string

                >>> remote_uploader_downloader = RemoteUploaderDownloader(..., file_path_format_string='rank_{rank}/{remote_file_name}')
                >>> trainer = Trainer(..., save_latest_filename=None, run_name='foo', loggers=[remote_uploader_downloader])
                >>> trainer.logger.upload_file(
                ...     remote_file_name='bar.txt',
                ...     file_path='path/to/file.txt',
                ... )

            .. testcleanup:: composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader.__init__.file_path_format_string

                # Shut down the uploader
                remote_uploader_downloader._check_workers()
                remote_uploader_downloader.post_close()

            Assuming that the process's rank is ``0``, the remote backend would store the contents of
            ``'path/to/file.txt'`` at ``'rank_0/bar.txt'``.

            Default: ``'{remote_file_name}'``

        num_concurrent_uploads (int, optional): Maximum number of concurrent uploads. Defaults to 1.
        upload_staging_folder (str, optional): A folder to use for staging uploads.
            If not specified, defaults to using a :func:`~tempfile.TemporaryDirectory`.
        use_procs (bool, optional): Whether to perform file uploads in background processes (as opposed to threads).
            Defaults to True.
        num_attempts (int, optional): For operations that fail with a transient error, the number of attempts to make.
            Defaults to 3.

    Raises:
        ValueError: If ``num_concurrent_uploads`` is less than 1.
    """

    def __init__(self,
                 bucket_uri: str,
                 backend_kwargs: Optional[Dict[str, Any]] = None,
                 file_path_format_string: str = '{remote_file_name}',
                 num_concurrent_uploads: int = 1,
                 upload_staging_folder: Optional[str] = None,
                 use_procs: bool = True,
                 num_attempts: int = 3) -> None:
        # Validate arguments before acquiring any resource (e.g. the temporary staging directory)
        if num_concurrent_uploads < 1:
            raise ValueError('num_concurrent_uploads must be >= 1. Blocking uploads are not supported.')
        self._num_concurrent_uploads = num_concurrent_uploads

        parsed_remote_bucket = urlparse(bucket_uri)
        self.remote_backend_name, self.remote_bucket_name = parsed_remote_bucket.scheme, parsed_remote_bucket.netloc

        # Work on a shallow copy: defaults are injected below and `init` may rewrite `path`,
        # neither of which should leak into the dictionary owned by the caller.
        self.backend_kwargs: Dict[str, Any] = dict(backend_kwargs) if backend_kwargs is not None else {}
        if self.remote_backend_name in ['s3', 'oci', 'gs'] and 'bucket' not in self.backend_kwargs:
            self.backend_kwargs['bucket'] = self.remote_bucket_name
        elif self.remote_backend_name == 'sftp' and 'host' not in self.backend_kwargs:
            self.backend_kwargs['host'] = f'sftp://{self.remote_bucket_name}'
        elif self.remote_backend_name == 'libcloud' and 'container' not in self.backend_kwargs:
            self.backend_kwargs['container'] = self.remote_bucket_name

        self.file_path_format_string = file_path_format_string
        self.num_attempts = num_attempts
        self._run_name = None

        if upload_staging_folder is None:
            self._tempdir = tempfile.TemporaryDirectory()
            self._upload_staging_folder = self._tempdir.name
        else:
            self._tempdir = None
            self._upload_staging_folder = upload_staging_folder

        # There could be multiple upload workers uploading to the same object
        # If multiple workers are uploading to the same object simultaneously (e.g. the checkpoint latest symlink file), then
        # The object store might keep the earlier file rather than the latter file as the "latest" version

        # To work around this, each object name can appear at most once in `self._file_upload_queue`
        # The main separately keeps track of {file_path_format_string: tempfile_path} for each API call to self.upload_file
        # and then periodically transfers items from this dictionary onto the file upload queue

        # Lock for modifying `logged_objects` or `enqueued_objects`
        # These objects are used by threads on the main process only
        self._object_lock = threading.Lock()

        # Files that were logged but yet to be enqueued. Mapping of the object name to the (tempfile path, overwrite) for that object
        self._logged_objects: Dict[str, Tuple[str, bool]] = {}

        # Set of enqueued objects. This should keep track of everything in self._file_upload_queue with O(1) lookup
        self._enqueued_objects: Set[str] = set()

        # Thread that runs `self._enqueue_uploads`
        self._enqueue_thread = None

        # Event to signal the enqueue thread to shut down.
        self._enqueue_thread_flag = None

        if use_procs:
            mp_ctx = multiprocessing.get_context('spawn')
            self._file_upload_queue: Union[queue.Queue[Tuple[str, str, bool]],
                                           multiprocessing.JoinableQueue[Tuple[str, str, bool]],
                                          ] = mp_ctx.JoinableQueue()
            self._completed_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str],
                                        ] = mp_ctx.JoinableQueue()
            self._exception_queue: Union[queue.Queue[Exception], multiprocessing.JoinableQueue[Exception],
                                        ] = mp_ctx.JoinableQueue()
            self._finished_cls: Union[Callable[[],
                                               multiprocessing._EventType],  # pyright: ignore[reportGeneralTypeIssues]
                                      Type[threading.Event]] = mp_ctx.Event
            self._proc_class = mp_ctx.Process
        else:
            self._file_upload_queue = queue.Queue()
            self._completed_queue = queue.Queue()
            self._exception_queue = queue.Queue()
            self._finished_cls = threading.Event
            self._proc_class = threading.Thread
        self._worker_flag: Optional[Union[multiprocessing._EventType,  # pyright: ignore[reportGeneralTypeIssues]
                                          threading.Event]] = None
        self._workers: List[Union[SpawnProcess, threading.Thread]] = []
        # the object store instance for the main thread. Deferring the construction of the object_store to first use.
        self._remote_backend = None

    @property
    def remote_backend(self) -> ObjectStore:
        """The :class:`.ObjectStore` instance for the main thread."""
        if self._remote_backend is None:
            self._remote_backend = _build_remote_backend(self.remote_backend_name, self.backend_kwargs)
        return self._remote_backend

    def init(self, state: State, logger: Logger) -> None:
        """Record the run name, validate credentials on rank zero and start the background workers.

        Raises:
            RuntimeError: If this instance is already initialized.
        """
        del logger  # unused
        if self._worker_flag is not None:
            raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
        self._worker_flag = self._finished_cls()
        self._run_name = state.run_name
        file_name_to_test = self._remote_file_name('.credentials_validated_successfully')

        # Create the enqueue thread
        self._enqueue_thread_flag = self._finished_cls()
        self._enqueue_thread = threading.Thread(target=self._enqueue_uploads, daemon=True)
        self._enqueue_thread.start()

        if dist.get_global_rank() == 0:
            retry(ObjectStoreTransientError,
                  self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))()

        # If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be
        # updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information
        # must be propagated across all ranks before the workers are started so that all workers use the same
        # MLFlow run.
        if self.backend_kwargs.get('path', '').startswith(MLFLOW_DBFS_PATH_PREFIX):
            if dist.get_global_rank() == 0:
                assert isinstance(self.remote_backend, MLFlowObjectStore)
                self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path'])

            path_list = [self.backend_kwargs['path']]
            dist.broadcast_object_list(path_list, src=0)
            self.backend_kwargs['path'] = path_list[0]

        assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None'
        for _ in range(self._num_concurrent_uploads):
            worker = self._proc_class(
                target=_upload_worker,
                kwargs={
                    'file_queue': self._file_upload_queue,
                    'is_finished': self._worker_flag,
                    'remote_backend_name': self.remote_backend_name,
                    'backend_kwargs': self.backend_kwargs,
                    'num_attempts': self.num_attempts,
                    'completed_queue': self._completed_queue,
                    'exception_queue': self._exception_queue,
                },
                # The worker threads are joined in the shutdown procedure, so it is OK to set the daemon status
                # Setting daemon status prevents the process from hanging if close was never called (e.g. in doctests)
                daemon=True,
            )
            worker.start()
            self._workers.append(worker)

    def batch_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        self._check_workers()

    def epoch_end(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        self._check_workers()

    @property
    def _all_workers_alive(self):
        """Whether all workers are alive."""
        return all(worker.is_alive() for worker in self._workers)

    def _check_workers(self):
        """Raise if any upload worker is no longer alive.

        Raises:
            Exception: The first exception found on the exception queue, if there is one.
            RuntimeError: If a worker is dead but the exception queue is empty.
        """
        # Periodically check to see if any of the upload workers crashed
        # They would crash if:
        #   a) A file could not be uploaded, and the retry counter failed, or
        #   b) overwrite=False, but the file already exists,
        if self._all_workers_alive:
            return
        if not self._exception_queue.empty():
            exception = self._exception_queue.get_nowait()
            raise exception
        raise RuntimeError('Upload worker crashed. Please check the logs.')

    def upload_file(
        self,
        state: State,
        remote_file_name: str,
        file_path: pathlib.Path,
        *,
        overwrite: bool,
    ):
        """Stage a copy of ``file_path`` and register it for background upload.

        Args:
            state (State): Unused.
            remote_file_name (str): Name of the remote file, before ``file_path_format_string`` is applied.
            file_path (pathlib.Path): Local file to upload. It is copied to the staging folder.
            overwrite (bool): Whether an object with the same name may be replaced.

        Raises:
            RuntimeError: If the run name has not been set yet.
            FileExistsError: If ``overwrite`` is False and the same object is already waiting to be enqueued.
        """
        del state  # unused

        # Resolve the remote name first so that a missing run name fails before anything is staged
        formatted_remote_file_name = self._remote_file_name(remote_file_name)

        os.makedirs(self._upload_staging_folder, exist_ok=True)
        copied_path = os.path.join(self._upload_staging_folder, str(uuid.uuid4()))
        shutil.copy2(file_path, copied_path)

        with self._object_lock:
            superseded = self._logged_objects.get(formatted_remote_file_name)
            if superseded is not None and not overwrite:
                # The copy made above will never be uploaded; best-effort removal so it does not linger
                try:
                    os.remove(copied_path)
                except OSError:
                    pass
                raise FileExistsError(
                    f'Object {formatted_remote_file_name} was already enqueued to be uploaded, but overwrite=False.')
            self._logged_objects[formatted_remote_file_name] = (copied_path, overwrite)

        if superseded is not None:
            # The replaced entry was still in `_logged_objects`, so it was never handed to a worker and
            # its staged copy is unreferenced now. Best-effort removal.
            try:
                os.remove(superseded[0])
            except OSError:
                pass

    def can_upload_files(self) -> bool:
        """Whether the logger supports uploading files."""
        return True

    def _enqueue_uploads(self):
        """Worker thread to enqueue uploads.

        This thread does two things:

        1.  It enqueues objects from ``self._logged_objects`` onto ``self._file_upload_queue``.
        2.  It keeps ``self._enqueued_objects`` in sync with ``self._file_upload_queue`` by listening to
            ``self._completed_queue``.
        """
        assert self._enqueue_thread_flag is not None
        while True:
            with self._object_lock:
                # Remove all objects from self._enqueued_objects that have been successfully uploaded
                while True:
                    try:
                        object_name = self._completed_queue.get_nowait()
                    except queue.Empty:
                        break
                    self._enqueued_objects.remove(object_name)
                    self._completed_queue.task_done()

                # Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue
                objects_to_delete = []
                for object_name, (copied_path, overwrite) in self._logged_objects.items():
                    if object_name in self._enqueued_objects:
                        continue
                    self._file_upload_queue.put_nowait((copied_path, object_name, overwrite))
                    objects_to_delete.append(object_name)
                    self._enqueued_objects.add(object_name)
                for object_name in objects_to_delete:
                    del self._logged_objects[object_name]

                # Shutdown if the enqueue thread flag is set, which means that no more objects will be added to
                # self._logged_objects
                if self._enqueue_thread_flag.is_set():
                    if self._all_workers_alive:
                        if len(self._logged_objects) == 0:
                            # If finished (i.e. no more objects to be added to self._logged_objects) and all logged objects are
                            # enqueued, then break
                            break
                    else:
                        # If any worker died, then it's impossible to recover since the file was already popped off of the queue,
                        # so break. Some files may not be uploaded.
                        break

            time.sleep(0.2)  # Yield lock for `self.upload_file`

    def download_file(
        self,
        remote_file_name: str,
        destination: str,
        overwrite: bool = False,
        progress_bar: bool = True,
    ):
        """Download ``remote_file_name`` to ``destination`` using the main-thread object store."""
        get_file(path=remote_file_name,
                 destination=destination,
                 object_store=self.remote_backend,
                 overwrite=overwrite,
                 progress_bar=progress_bar)

    def fit_end(self, state: State, logger: Logger):
        self.wait_for_workers(state.device)

    def eval_standalone_end(self, state: State, logger: Logger):
        self.wait_for_workers(state.device)

    def predict_end(self, state: State, logger: Logger):
        self.wait_for_workers(state.device)

    def wait_for_workers(self, device: Device):
        """Wait for all tasks to be completed.

        This is called after fit/eval/predict. If we don't wait, then a worker might not schedule
        an upload before the interpreter is shutdown and garbage collection begins. While
        post_close logic ensures existing uploads are completed, trying to schedule new uploads
        during this time will error.

        Args:
            device (Device): Device used to place the tensor that is reduced across ranks.

        Raises:
            Exception: The first exception found on the exception queue, if there is one.
        """
        # Verify enqueue thread has processed all tasks unless a worker threw an exception
        while self._exception_queue.empty():
            with self._object_lock:
                if len(self._logged_objects) == 0:
                    break
            time.sleep(0.2)  # Yield lock for enqueue thread

        # Verify all tasks have been completed unless a worker threw an exception.
        # The flag is 1 while this rank still has queued files and no reported exception; the MAX
        # reduction keeps every rank waiting until no rank reports 1.
        while True:
            upload_pending = int(not self._file_upload_queue.empty() and self._exception_queue.empty())
            all_ranks_upload_done_tensor = device.tensor_to_device(torch.tensor([upload_pending], dtype=torch.uint8))
            dist.all_reduce(all_ranks_upload_done_tensor, reduce_operation='MAX')
            if all_ranks_upload_done_tensor.item() != 1:
                break
            time.sleep(2)

        if not self._exception_queue.empty():
            e = self._exception_queue.get_nowait()
            raise e

    def post_close(self):
        """Shut down the enqueue thread and the workers, then reset the instance state.

        Emits a :class:`RuntimeWarning` listing every object that was logged or enqueued but not
        confirmed as uploaded.
        """
        # Shutdown logic:
        # 1. Signal to the enqueue thread that all uploads are enqueued. Specifically.
        #    set a flag indicating that that no more objects will be added to self._logged_objects.
        # 2. Wait for the enqueue thread to shut down. It will only shut down once all objects are added to
        #    self._file_upload_queue. This will mean that self._logged_objects is empty.
        # 3. Send a flag to the workers that all uploads are enqueued in self._file_upload_queue.
        # 4. Wait for the workers to shut down. This means that all files have been uploaded
        if self._enqueue_thread_flag is not None:
            self._enqueue_thread_flag.set()

        if self._enqueue_thread is not None:
            self._enqueue_thread.join()

        if self._worker_flag is not None:
            self._worker_flag.set()

        # Then, ensure all workers have finished all uploads
        for worker in self._workers:
            worker.join()

        # Clean up the tempdir
        if self._tempdir is not None:
            self._tempdir.cleanup()

        # Empty the completed queue
        # This cleanup will not be done by the enqueue_thread anymore, as that thread has been shut down
        while True:
            try:
                object_name = self._completed_queue.get_nowait()
            except queue.Empty:
                break
            self._enqueued_objects.remove(object_name)
            self._completed_queue.task_done()

        if len(self._enqueued_objects) > 0 or len(self._logged_objects) > 0:
            # Warn on all objects that have not been uploaded: both the ones handed to the workers
            # and the ones that never made it onto the upload queue.
            object_names = list(self._enqueued_objects)
            object_names.extend(self._logged_objects.keys())
            warnings.warn(
                RuntimeWarning('The following objects may not have been uploaded, likely due to a worker crash: ' +
                               ', '.join(object_names)))

        # Reset all variables
        self._logged_objects.clear()
        self._enqueued_objects.clear()
        self._enqueue_thread = None
        self._tempdir = None
        self._worker_flag = None
        self._enqueue_thread_flag = None
        self._workers.clear()

    def get_uri_for_file(self, remote_file_name: str) -> str:
        """Get the object store provider uri for a remote file.

        Args:
            remote_file_name (str): The name of a remote file.

        Returns:
            str: The uri corresponding to the uploaded location of the remote file.
        """
        formatted_remote_file_name = self._remote_file_name(remote_file_name)
        return self.remote_backend.get_uri(formatted_remote_file_name.lstrip('/'))

    def _remote_file_name(self, remote_file_name: str):
        """Format the ``remote_file_name`` according to the ``file_path_format_string``.

        Raises:
            RuntimeError: If the run name has not been set yet.
        """
        if self._run_name is None:
            raise RuntimeError('The run name is not set. It should have been set on Event.INIT.')
        key_name = format_name_with_dist(
            self.file_path_format_string,
            run_name=self._run_name,
            remote_file_name=remote_file_name,
        )
        key_name = key_name.lstrip('/')

        return key_name
def _validate_credentials(
    remote_backend: ObjectStore,
    remote_file_name_to_test: str,
) -> None:
    """Validate the credentials by uploading a small marker file to the bucket.

    Args:
        remote_backend (ObjectStore): The object store to upload to.
        remote_file_name_to_test (str): Object name to use for the marker file.

    Any exception raised by ``remote_backend.upload_object`` (e.g. on a credentials failure) propagates
    to the caller.
    """
    with tempfile.NamedTemporaryFile('wb') as f:
        f.write(b'credentials_validated_successfully')
        # The upload reads the file by name, so the buffered bytes must reach the file first;
        # otherwise an empty object would be uploaded.
        f.flush()
        remote_backend.upload_object(
            object_name=remote_file_name_to_test,
            filename=f.name,
        )


def _upload_worker(
    file_queue: Union[queue.Queue[Tuple[str, str, bool]], multiprocessing.JoinableQueue[Tuple[str, str, bool]]],
    completed_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]],
    exception_queue: Union[queue.Queue[Exception], multiprocessing.JoinableQueue[Exception]],
    is_finished: Union[multiprocessing._EventType, threading.Event],  # pyright: ignore[reportGeneralTypeIssues]
    remote_backend_name: str,
    backend_kwargs: Dict[str, Any],
    num_attempts: int,
):
    """A long-running function to handle uploading files to the object store.

    The worker will continuously poll ``file_queue`` for files to upload. Once ``is_finished`` is set, the worker
    will exit once ``file_queue`` is empty.

    Args:
        file_queue: Queue of ``(local file path, remote file name, overwrite)`` tuples to upload.
        completed_queue: Queue that receives the remote file name of every successful upload.
        exception_queue: Queue that receives the exception that terminates this worker.
        is_finished: Event signalling that no more files will be added to ``file_queue``.
        remote_backend_name (str): Backend identifier passed to ``_build_remote_backend``.
        backend_kwargs (Dict[str, Any]): Keyword arguments passed to ``_build_remote_backend``.
        num_attempts (int): Number of attempts for uploads failing with :class:`.ObjectStoreTransientError`.
    """
    remote_backend = _build_remote_backend(remote_backend_name, backend_kwargs)
    while True:
        try:
            file_path_to_upload, remote_file_name, overwrite = file_queue.get(block=True, timeout=0.5)
        except queue.Empty:
            if is_finished.is_set():
                break
            else:
                continue
        uri = remote_backend.get_uri(remote_file_name)

        # defining as a function-in-function to use decorator notation with num_attempts as an argument
        @retry(ObjectStoreTransientError, num_attempts=num_attempts)
        def upload_file():
            if not overwrite:
                try:
                    remote_backend.get_object_size(remote_file_name)
                except FileNotFoundError:
                    # Good! It shouldn't exist.
                    pass
                else:
                    raise FileExistsError(f'Object {uri} already exists, but overwrite was set to False.')
            log.info('Uploading file %s to %s', file_path_to_upload, uri)
            remote_backend.upload_object(
                object_name=remote_file_name,
                filename=file_path_to_upload,
            )

        # Report only the failure that actually terminates the worker. Reporting from inside the
        # retried function would also queue transient errors that a later attempt recovers from,
        # and the main process would then raise them even though the upload succeeded.
        # Exceptions will be detected on the next batch_end or epoch_end event
        try:
            upload_file()
        except Exception as e:
            exception_queue.put_nowait(e)
            raise

        os.remove(file_path_to_upload)
        file_queue.task_done()
        completed_queue.put_nowait(remote_file_name)