# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""Write files to remote location which can be either Cloud Storage or a local path."""
import logging
import os
import pathlib
import shutil
import sys
import urllib.parse
from enum import Enum
from tempfile import mkdtemp
from typing import Any, Optional, Union
import tqdm
from streaming.base.storage.download import (BOTOCORE_CLIENT_ERROR_CODES,
GCS_ERROR_NO_AUTHENTICATION)
from streaming.base.util import get_import_exception_message, retry
__all__ = [
'CloudUploader',
'S3Uploader',
'GCSUploader',
'OCIUploader',
'HFUploader',
'AzureUploader',
'AzureDataLakeUploader',
'DatabricksUnityCatalogUploader',
'DBFSUploader',
'AlipanUploader',
'LocalUploader',
]
logger = logging.getLogger(__name__)
UPLOADERS = {
's3': 'S3Uploader',
'gs': 'GCSUploader',
'oci': 'OCIUploader',
'hf': 'HFUploader',
'azure': 'AzureUploader',
'azure-dl': 'AzureDataLakeUploader',
'dbfs:/Volumes': 'DatabricksUnityCatalogUploader',
'dbfs': 'DBFSUploader',
'alipan': 'AlipanUploader',
'': 'LocalUploader',
}
class GCSAuthentication(Enum):
HMAC = 1
SERVICE_ACCOUNT = 2
class CloudUploader:
"""Upload local files to cloud storage or a local directory."""
@classmethod
def get(cls,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> Any:
"""Instantiate a cloud provider uploader or a local uploader based on remote path.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
Returns:
CloudUploader: An instance of sub-class.
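Example:
    A minimal usage sketch; the S3 bucket, prefix, and shard filename are
    placeholders, and the shard is assumed to already exist under the
    returned uploader's ``local`` directory::

        from streaming.base.storage.upload import CloudUploader

        uploader = CloudUploader.get(out='s3://my-bucket/dataset', keep_local=True)  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename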
"""
cls._validate(cls, out)
obj = urllib.parse.urlparse(out) if isinstance(out, str) else urllib.parse.urlparse(out[1])
provider_prefix = obj.scheme
if obj.scheme == 'dbfs':
path = pathlib.Path(out) if isinstance(out, str) else pathlib.Path(out[1])
prefix = os.path.join(path.parts[0], path.parts[1])
if prefix == 'dbfs:/Volumes':
provider_prefix = prefix
return getattr(sys.modules[__name__],
UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok)
def _validate(self, out: Union[str, tuple[str, str]]) -> None:
"""Validate the `out` argument.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
Raises:
ValueError: Invalid number of `out` argument.
ValueError: Invalid Cloud provider prefix.
"""
if isinstance(out, str):
obj = urllib.parse.urlparse(out)
else:
if len(out) != 2:
raise ValueError('Invalid `out` argument. It must be either a string ' +
'(a local or remote directory) or a tuple of two strings ' +
'(local, remote).')
obj = urllib.parse.urlparse(out[1])
if obj.scheme not in UPLOADERS:
raise ValueError(f'Invalid Cloud provider prefix: {obj.scheme}.')
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
"""Initialize and validate local and remote path.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
Raises:
FileExistsError: Local directory must be empty.
"""
self._validate(out)
self.keep_local = keep_local
self.progress_bar = progress_bar
self.retry = retry
if isinstance(out, str):
# It is a remote directory
if urllib.parse.urlparse(out).scheme != '':
self.local = mkdtemp()
self.remote = out
# It is a local directory
else:
self.local = out
self.remote = None
else:
self.local = out[0]
self.remote = out[1]
if os.path.exists(self.local) and len(os.listdir(self.local)) != 0:
if not exist_ok:
raise FileExistsError(f'Directory is not empty: {self.local}')
else:
logger.warning(
f'Directory {self.local} exists and is not empty. Continuing since exist_ok is set to True.'
)
os.makedirs(self.local, exist_ok=True)
def upload_file(self, filename: str):
"""Upload file from local instance to remote instance.
Args:
filename (str): File to upload.
Raises:
NotImplementedError: Override this method in your sub-class.
"""
raise NotImplementedError(f'{type(self).__name__}.upload_file is not implemented')
def list_objects(self, prefix: Optional[str] = None) -> Optional[list[str]]:
"""List all objects in the object store with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``.
Returns:
List[str]: A list of object names that match the prefix.
"""
raise NotImplementedError(f'{type(self).__name__}.list_objects is not implemented')
def clear_local(self, local: str):
"""Remove the local file unless ``keep_local`` is set.
Args:
local (str): A local file path.
"""
if not self.keep_local and os.path.isfile(local):
os.remove(local)
class S3Uploader(CloudUploader):
"""Upload file from local machine to an AWS S3 bucket (or any S3-compatible object store).
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
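Example:
    A minimal usage sketch; the directories, bucket, and shard filename are
    placeholders, and the shard is assumed to already exist under the local
    directory::

        from streaming.base.storage.upload import S3Uploader

        uploader = S3Uploader(out=('/tmp/shards', 's3://my-bucket/dataset'))  # placeholder paths
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename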
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
import boto3
from botocore.config import Config
config = Config()
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.s3 = session.client('s3',
config=config,
endpoint_url=os.environ.get('S3_ENDPOINT_URL'))
self.check_bucket_exists(self.remote) # pyright: ignore
def upload_file(self, filename: str):
"""Upload file from local instance to AWS S3 bucket.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
extra_args = {}
# Check for canned ACL environment variable
canned_acl = os.environ.get('S3_CANNED_ACL')
if canned_acl is not None:
extra_args['ACL'] = canned_acl
local_filename = os.path.join(self.local, filename)
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
self.s3.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
ExtraArgs=extra_args,
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)
self.clear_local(local=local_filename)
_upload_file()
def check_bucket_exists(self, remote: str):
"""Raise an exception if the bucket does not exist.
Args:
remote (str): S3 bucket path.
Raises:
error: Bucket does not exist.
"""
from botocore.exceptions import ClientError
bucket_name = urllib.parse.urlparse(remote).netloc
try:
self.s3.head_bucket(Bucket=bucket_name)
except ClientError as error:
if error.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES:
error.args = (f'Either bucket `{bucket_name}` does not exist or ' +
f'you do not have permission to access it.',)
raise error
def list_objects(self, prefix: Optional[str] = None) -> Optional[list[str]]:
"""List all objects in the S3 object store with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``.
Returns:
List[str]: A list of object names that match the prefix.
"""
if prefix is None:
prefix = ''
obj = urllib.parse.urlparse(self.remote)
bucket_name = obj.netloc
prefix = os.path.join(str(obj.path).lstrip('/'), prefix)
paginator = self.s3.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
try:
return [obj['Key'] for page in pages for obj in page['Contents']]
except KeyError:
return []
class GCSUploader(CloudUploader):
"""Upload file from local machine to Google Cloud Storage bucket.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
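Example:
    A minimal usage sketch; authentication comes either from the ``GCS_KEY`` /
    ``GCS_SECRET`` environment variables (HMAC) or from Google default
    credentials, and the bucket and shard names are placeholders::

        from streaming.base.storage.upload import GCSUploader

        uploader = GCSUploader(out='gs://my-bucket/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename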
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
import boto3
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.gcs_client = session.client(
's3',
region_name='auto',
endpoint_url='https://storage.googleapis.com',
aws_access_key_id=os.environ['GCS_KEY'],
aws_secret_access_key=os.environ['GCS_SECRET'],
)
self.authentication = GCSAuthentication.HMAC
else:
from google.auth import default as default_auth
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.storage import Client
try:
credentials, _ = default_auth()
self.gcs_client = Client(credentials=credentials)
self.authentication = GCSAuthentication.SERVICE_ACCOUNT
except (DefaultCredentialsError, EnvironmentError):
raise ValueError(GCS_ERROR_NO_AUTHENTICATION)
self.check_bucket_exists(self.remote) # pyright: ignore
def upload_file(self, filename: str) -> None:
"""Upload file from local instance to Google Cloud Storage bucket.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
if self.authentication == GCSAuthentication.HMAC:
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(
total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar),
) as pbar:
self.gcs_client.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
from google.cloud.storage import Blob, Bucket
blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc))
blob.upload_from_filename(local_filename)
self.clear_local(local=local_filename)
_upload_file()
def check_bucket_exists(self, remote: str) -> None:
"""Raise an exception if the bucket does not exist.
Args:
remote (str): GCS bucket path.
Raises:
error: Bucket does not exist.
"""
bucket_name = urllib.parse.urlparse(remote).netloc
if self.authentication == GCSAuthentication.HMAC:
from botocore.exceptions import ClientError
try:
self.gcs_client.head_bucket(Bucket=bucket_name)
except ClientError as error:
if error.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES:
error.args = (f'Either bucket `{bucket_name}` does not exist or ' +
f'you do not have permission to access it.',)
raise error
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
self.gcs_client.get_bucket(bucket_name)
def list_objects(self, prefix: Optional[str] = None) -> Optional[list[str]]:
"""List all objects in the GCS object store with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to None.
Returns:
List[str]: A list of object names that match the prefix.
"""
if prefix is None:
prefix = ''
obj = urllib.parse.urlparse(self.remote)
bucket_name = obj.netloc
if self.authentication == GCSAuthentication.HMAC:
prefix = os.path.join(str(obj.path).lstrip('/'), prefix)
paginator = self.gcs_client.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
try:
return [obj['Key'] for page in pages for obj in page['Contents']]
except KeyError:
return []
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
prefix = os.path.join(str(obj.path).lstrip('/'), prefix)
return [
b.name for b in self.gcs_client.get_bucket(bucket_name).list_blobs(prefix=prefix)
]
class OCIUploader(CloudUploader):
"""Upload file from local machine to Oracle Cloud Infrastructure (OCI) Cloud Storage.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
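Example:
    A minimal usage sketch; credentials are read via ``oci.config.from_file()``,
    ``my-namespace`` stands for your tenancy's Object Storage namespace, and the
    bucket and shard names are placeholders::

        from streaming.base.storage.upload import OCIUploader

        uploader = OCIUploader(out='oci://my-bucket@my-namespace/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename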
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
import oci
config = oci.config.from_file()
self.client = oci.object_storage.ObjectStorageClient(
config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
self.namespace = self.client.get_namespace().data
self.upload_manager = oci.object_storage.UploadManager(self.client)
self.check_bucket_exists(self.remote) # pyright: ignore
def upload_file(self, filename: str):
"""Upload file from local instance to OCI Cloud Storage bucket.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
bucket_name = obj.netloc.split('@' + self.namespace)[0]
# Remove leading and trailing forward slash from string
object_path = obj.path.strip('/')
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
self.upload_manager.upload_file(
namespace_name=self.namespace,
bucket_name=bucket_name,
object_name=object_path,
file_path=local_filename,
progress_callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)
self.clear_local(local=local_filename)
_upload_file()
def check_bucket_exists(self, remote: str):
"""Raise an exception if the bucket does not exist.
Args:
remote (str): OCI bucket path.
Raises:
error: Bucket does not exist.
"""
from oci.exceptions import ServiceError
obj = urllib.parse.urlparse(remote)
bucket_name = obj.netloc.split('@' + self.namespace)[0]
try:
self.client.head_bucket(bucket_name=bucket_name, namespace_name=self.namespace)
except ServiceError as error:
if error.status == 404:
error.args = (f'Bucket `{bucket_name}` does not exist! ' +
f'Check the bucket permission or create the bucket.',)
raise error
def list_objects(self, prefix: Optional[str] = None) -> Optional[list[str]]:
"""List all objects in the OCI object store with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``.
Returns:
List[str]: A list of object names that match the prefix.
"""
if prefix is None:
prefix = ''
obj = urllib.parse.urlparse(self.remote)
bucket_name = obj.netloc.split('@' + self.namespace)[0]
prefix = os.path.join(str(obj.path).strip('/'), prefix)
object_names = []
next_start_with = None
response_complete = False
import oci  # imported here as well; the import in __init__ is local to that method
try:
while not response_complete:
response = self.client.list_objects(namespace_name=self.namespace,
bucket_name=bucket_name,
prefix=prefix,
start=next_start_with).data
object_names.extend([resp_obj.name for resp_obj in response.objects])
next_start_with = response.next_start_with
if not next_start_with:
response_complete = True
return object_names
except Exception as e:
if isinstance(e, oci.exceptions.ServiceError): # type: ignore
if e.status == 404: # type: ignore
if e.code == 'ObjectNotFound': # type: ignore
raise FileNotFoundError(
f'Object {bucket_name}/{prefix} not found. {e.message}' # type: ignore
) from e # type: ignore
if e.code == 'BucketNotFound': # type: ignore
raise ValueError(
f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore
raise e
raise e
return []
class HFUploader(CloudUploader):
"""Upload file from local machine to a Hugging Face dataset repository.
Args:
out (str): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory then the shard files are uploaded to the
remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
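Example:
    A minimal usage sketch; the path follows the ``hf://datasets/<org>/<repo>/<dir>``
    layout implied by ``__init__``, the org/repo names are placeholders, and
    ``HF_TOKEN`` may be set for authenticated access::

        from streaming.base.storage.upload import HFUploader

        uploader = HFUploader(out='hf://datasets/my-org/my-dataset/data')  # placeholder repo
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename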
"""
def __init__(self,
out: str,
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
import huggingface_hub
self.api = huggingface_hub.HfApi()
self.fs = huggingface_hub.HfFileSystem(token=os.environ.get('HF_TOKEN', None))
obj = urllib.parse.urlparse(out)
if obj.scheme != 'hf':
raise ValueError(f'Expected remote path to start with `hf://`, got {out}.')
_, _, _, self.repo_org, self.repo_name, self.path = out.split('/', 5)
self.dataset_id = os.path.join(self.repo_org, self.repo_name)
self.check_dataset_exists() # pyright: ignore
def upload_file(self, filename: str):
"""Upload file from local instance to HF.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = filename
local_filename = local_filename.replace('\\', '/')
remote_filename = os.path.join('datasets', self.dataset_id, filename)
remote_filename = remote_filename.replace('\\', '/')
logger.debug(f'Uploading to {remote_filename}')
with self.fs.open(remote_filename, 'wb') as f:
with open(local_filename, 'rb') as data:
f.write(data.read())
_upload_file()
def check_dataset_exists(self):
"""Raise an exception if the dataset does not exist.
Raises:
error: Dataset does not exist.
"""
import huggingface_hub
try:
_ = list(huggingface_hub.list_repo_tree(self.dataset_id, repo_type='dataset'))
except Exception:
raise FileNotFoundError(
f'The HF dataset {self.dataset_id} could not be found. Please make sure ' +
f'that the dataset exists and you have the correct access permissions.')
class AzureUploader(CloudUploader):
"""Upload file from local machine to Microsoft Azure bucket.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
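Example:
    A minimal usage sketch; ``AZURE_ACCOUNT_NAME`` and ``AZURE_ACCOUNT_ACCESS_KEY``
    must be set in the environment, and the container and shard names are
    placeholders::

        from streaming.base.storage.upload import AzureUploader

        uploader = AzureUploader(out='azure://my-container/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename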
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
from azure.storage.blob import BlobServiceClient
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
self.azure_service = BlobServiceClient(
account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
)
self.check_bucket_exists(self.remote) # pyright: ignore
def upload_file(self, filename: str):
"""Upload file from local instance to Microsoft Azure bucket.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
local_filename = local_filename.replace('\\', '/')
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
remote_filename = remote_filename.replace('\\', '/')
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
container_client = self.azure_service.get_container_client(container=obj.netloc)
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
with open(local_filename, 'rb') as data:
container_client.upload_blob(
name=obj.path.lstrip('/'),
data=data,
progress_hook=lambda bytes_transferred, _: pbar.update(bytes_transferred),
overwrite=True)
self.clear_local(local=local_filename)
_upload_file()
def check_bucket_exists(self, remote: str):
"""Raise an exception if the bucket does not exist.
Args:
remote (str): azure bucket path.
Raises:
error: Bucket does not exist.
"""
bucket_name = urllib.parse.urlparse(remote).netloc
if not self.azure_service.get_container_client(container=bucket_name).exists():
raise FileNotFoundError(f'Either bucket `{bucket_name}` does not exist or ' +
f'you do not have permission to access it.')
class AzureDataLakeUploader(CloudUploader):
"""Upload file from local machine to Microsoft Azure DataLake.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
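Example:
    A minimal usage sketch; ``AZURE_ACCOUNT_NAME`` and ``AZURE_ACCOUNT_ACCESS_KEY``
    must be set in the environment, and the file system and shard names are
    placeholders::

        from streaming.base.storage.upload import AzureDataLakeUploader

        uploader = AzureDataLakeUploader(out='azure-dl://my-filesystem/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename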
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
from azure.storage.filedatalake import DataLakeServiceClient
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
self.azure_service = DataLakeServiceClient(
account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.dfs.core.windows.net",
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'])
self.check_container_exists(self.remote) # pyright: ignore
def upload_file(self, filename: str):
"""Upload file from local instance to Azure DataLake container.
Args:
filename (str): File to upload.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
local_filename = local_filename.replace('\\', '/')
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
remote_filename = remote_filename.replace('\\', '/')
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
file_client = self.azure_service.get_file_client(file_system=obj.netloc,
file_path=obj.path.lstrip('/'))
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
with open(local_filename, 'rb') as data:
file_client.upload_data(data=data, overwrite=True)
pbar.update(file_size)
self.clear_local(local=local_filename)
_upload_file()
def check_container_exists(self, remote: str):
"""Raise an exception if the container does not exist.
Args:
remote (str): azure container path.
Raises:
error: Container does not exist.
"""
container_name = urllib.parse.urlparse(remote).netloc
if not self.azure_service.get_file_system_client(file_system=container_name).exists():
raise FileNotFoundError(f'Either container `{container_name}` does not exist or ' +
f'you do not have permission to access it.')
class DatabricksUploader(CloudUploader):
"""Parent class for uploading files from local machine to a Databricks workspace.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
self.client = self._create_workspace_client()
def _create_workspace_client(self):
try:
from databricks.sdk import WorkspaceClient
return WorkspaceClient()
except ImportError as e:
e.msg = get_import_exception_message(e.name, 'databricks') # pyright: ignore
raise e
class DatabricksUnityCatalogUploader(DatabricksUploader):
"""Upload file from local machine to Databricks Unity Catalog.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
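Example:
    A minimal usage sketch; workspace credentials are resolved by
    ``databricks.sdk.WorkspaceClient()``, and the Volumes path and shard name
    are placeholders::

        from streaming.base.storage.upload import DatabricksUnityCatalogUploader

        uploader = DatabricksUnityCatalogUploader(out='dbfs:/Volumes/my-catalog/my-schema/my-volume/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename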
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
def upload_file(self, filename: str):
"""Upload file from local instance to Databricks Unity Catalog.
Args:
filename (str): Relative filepath to copy.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
local_filename = local_filename.replace('\\', '/')
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
remote_filename = remote_filename.replace('\\', '/')
remote_filename_wo_prefix = urllib.parse.urlparse(remote_filename).path
file_size = os.stat(local_filename).st_size
with open(local_filename, 'rb') as f:
self.client.files.upload(remote_filename_wo_prefix, f, overwrite=True)
# Warning!
# filesAPI.upload fails silently when failed to upload!
# need to manually check HEAD and throw exception to retry
metadata = self.client.files.get_metadata(remote_filename_wo_prefix)
if file_size != metadata.content_length:
raise RuntimeError(f'Uploading failed! {file_size}!= {metadata.content_length}')
_upload_file()
class DBFSUploader(DatabricksUploader):
"""Upload file from local machine to Databricks File System (DBFS).
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
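Example:
    A minimal usage sketch; workspace credentials are resolved by
    ``databricks.sdk.WorkspaceClient()``, and the DBFS path and shard name are
    placeholders::

        from streaming.base.storage.upload import DBFSUploader

        uploader = DBFSUploader(out='dbfs:/my-folder/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename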
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
self.dbfs_path = self.remote.removeprefix('dbfs:')  # pyright: ignore  # strip only the literal scheme
self.check_folder_exists()
def upload_file(self, filename: str):
"""Upload file from local instance to DBFS. Does not overwrite.
Args:
filename (str): Relative filepath to copy.
"""
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
local_filename = local_filename.replace('\\', '/')
remote_filename = os.path.join(self.dbfs_path, filename)
remote_filename = remote_filename.replace('\\', '/')
file_path = urllib.parse.urlparse(remote_filename)
with open(local_filename, 'rb') as f:
self.client.dbfs.upload(file_path.path, f)
_upload_file()
def check_folder_exists(self):
"""Raise an exception if the DBFS folder does not exist.
Raises:
error: Folder does not exist.
"""
from databricks.sdk.core import DatabricksError
try:
if not self.client.dbfs.exists(self.dbfs_path):
raise FileNotFoundError(f'Databricks File System path {self.dbfs_path} not found')
except DatabricksError as e:
if e.error_code == 'PERMISSION_DENIED':
e.args = (
f'Ensure the file path or credentials are set correctly. For ' +
f'Databricks Unity Catalog, the file path must start with `dbfs:/Volumes`, ' +
f'and for Databricks File System, the file path must start with `dbfs`. ' +
e.args[0],)
raise e
class AlipanUploader(CloudUploader):
"""Upload file from local machine to Alipan.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
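Example:
    A minimal usage sketch; ``ALIPAN_WEB_REFRESH_TOKEN`` must be set (optionally
    ``ALIPAN_ENCRYPT_TYPE`` and ``ALIPAN_ENCRYPT_PASSWORD``), the remote path takes
    the ``alipan:///path`` form, and the directory and shard names are placeholders::

        from streaming.base.storage.upload import AlipanUploader

        uploader = AlipanUploader(out='alipan:///my-datasets/dataset')  # placeholder path
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename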
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
obj = urllib.parse.urlparse(self.remote)
if obj.scheme != 'alipan':
raise ValueError(
f'Expected obj.scheme to be `alipan`, instead, got {obj.scheme} for remote={self.remote}'
)
if obj.netloc != '':
raise ValueError(
f'Expected remote to be alipan:///path/to/some, instead, got remote={self.remote}')
from alipcs_py.alipcs import AliPCSApiMix
from alipcs_py.commands.upload import EncryptType
web_refresh_token = os.environ['ALIPAN_WEB_REFRESH_TOKEN']
web_token_type = 'Bearer'
self.alipan_encrypt_password = os.environ.get('ALIPAN_ENCRYPT_PASSWORD', '').encode()
self.alipan_encrypt_type = EncryptType.No # No encryption by default
encrypt_type = os.environ.get('ALIPAN_ENCRYPT_TYPE', '')
if encrypt_type:
if getattr(EncryptType, encrypt_type, None) is None:
raise ValueError(
f'Invalid ALIPAN_ENCRYPT_TYPE: {encrypt_type}. ' +
'Encryption type must be one of `Simple`, `ChaCha20`, `AES256CBC`')
else:
self.alipan_encrypt_type = getattr(EncryptType, encrypt_type)
self.api = AliPCSApiMix(web_refresh_token, web_token_type=web_token_type)
self.check_token()
def check_token(self):
"""Raise an exception if the refresh_token is invalid.
Raises:
error: AliPCSError with code `AccessTokenInvalid`
"""
self.api.refresh()
def upload_file(self, filename: str):
"""Upload file from local instance to Alipan directory.
Args:
filename (str): File to upload.
"""
from alipcs_py.commands import upload as alipcs_upload
from alipcs_py.commands.upload import EncryptType, total_len
@retry(num_attempts=self.retry)
def _upload_file():
local_filename = os.path.join(self.local, filename)
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
if self.alipan_encrypt_type == EncryptType.No:
file_size = os.stat(local_filename).st_size
else:
encrypt_io = self.alipan_encrypt_type.encrypt_io(open(local_filename, 'rb'),
self.alipan_encrypt_password)
file_size = total_len(encrypt_io)
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
alipcs_upload.upload_file(
self.api, (local_filename, obj.path),
'overwrite',
encrypt_password=self.alipan_encrypt_password,
encrypt_type=self.alipan_encrypt_type,
callback_for_monitor=lambda offset: pbar.update(offset - pbar.n))
self.clear_local(local=local_filename)
_upload_file()
def list_objects(self, prefix: Optional[str] = None) -> list[str]:
"""List all objects in the remote path with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``.
Returns:
List[str]: A list of object names that match the prefix.
"""
if prefix is None:
prefix = ''
if self.remote is None:
raise ValueError('Alipan remote path must be set.')
obj = urllib.parse.urlparse(self.remote)
remote_pcs_file = self.api.get_file(remotepath=obj.path)
if remote_pcs_file is None:
raise FileNotFoundError(f'Alipan remote path `{obj.path}` not found.')
file_paths = []
for pcs_file in self.api.list_iter(remote_pcs_file.file_id, recursive=True):
file_path = pathlib.PosixPath(obj.path, pcs_file.path).as_posix()
if file_path.startswith(prefix):
file_paths.append(file_path)
return sorted(file_paths)
class LocalUploader(CloudUploader):
"""Copy file from one local directory to another local directory.
Args:
out (str | Tuple[str, str]): Output dataset directory to save shard files.
1. If ``out`` is a local directory, shard files are saved locally.
2. If ``out`` is a remote directory, a local temporary directory is created to
cache the shard files and then the shard files are uploaded to a remote
location. At the end, the temp directory is deleted once shards are uploaded.
3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
the `local_dir` and also uploaded to a remote location.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
shard file or remove it after uploading. Defaults to ``False``.
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Defaults to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): If ``False``, raise an error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
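Example:
    A minimal usage sketch; both directories and the shard name are placeholders::

        from streaming.base.storage.upload import LocalUploader

        uploader = LocalUploader(out=('/tmp/cache', '/data/dataset'))  # placeholder paths
        uploader.upload_file('shard.00000.mds')  # placeholder shard filename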
"""
def __init__(self,
out: Union[str, tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
# Create remote directory if it doesn't exist
if self.remote:
os.makedirs(self.remote, exist_ok=True)
def upload_file(self, filename: str):
"""Copy file from one local path to another local path.
Args:
filename (str): Relative filepath to copy.
"""
@retry(num_attempts=self.retry)
def _upload_file():
if self.remote:
local_filename = os.path.join(self.local, filename)
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
logger.debug(f'Copying to {remote_filename}')
shutil.copy(local_filename, remote_filename)
self.clear_local(local=local_filename)
_upload_file()
def list_objects(self, prefix: Optional[str] = None) -> list[str]:
"""List all objects locally with the given prefix.
Args:
prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``.
Returns:
List[str]: A list of object names that match the prefix.
"""
if prefix is None:
prefix = ''
prefix = os.path.join(self.local, prefix)
file_paths = []
for dirpath, _, files in os.walk(self.local):
for file in files:
file_path = os.path.join(dirpath, file)
if file_path.startswith(prefix):
file_paths.append(file_path)
return sorted(file_paths)