# Source code for composer.utils.compression

# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utilities for creating and loading compressed files."""

import shutil
import subprocess
from contextlib import contextmanager
from typing import IO, Iterator, List, Optional

# Public API. ``CompressorNotFound`` is exported because public functions in this module
# (``get_compressor``, ``CliCompressor.check_exists``) raise it, so callers need to catch it.
__all__ = ['is_compressed_pt', 'CliCompressor', 'get_compressor', 'KNOWN_COMPRESSORS', 'CompressorNotFound']


class CompressorNotFound(FileNotFoundError):
    """Raised when a compressor CLI tool is missing from the PATH, or no compressor handles a file's extension."""


def is_compressed_pt(filename: str) -> bool:
    """Check whether ``filename`` names a directly compressed ``.pt`` file.

    A directly compressed pt file is a raw compressed stream of a single pt file with no
    container format (like tar). It is recognized purely by name: the second-to-last
    dot-separated component must be exactly ``pt`` (for example ``'ep0-rank0.pt.lz4'``).

    Args:
        filename (str): File name or path to inspect. Only the string is examined; nothing is opened.

    Returns:
        bool: ``True`` if the name contains at least one ``.`` and the component just before
        the last ``.`` is ``pt``.
    """
    if '.' not in filename:
        # No dot means no extension at all, let alone a ``pt.<ext>`` pair.
        return False
    # Splitting from the right at most twice yields the same trailing components as a full split('.').
    return filename.rsplit('.', 2)[-2] == 'pt'
class CliCompressor:
    """Base class for data compression CLI tools.

    This class handles compression and decompression of data by piping it through CLI compressor tools
    installed on the system, e.g. the `gzip` command for producing `.gz` files. The streams it hands out
    are binary, so they accept and return ``bytes``.

    Example:
    .. code-block:: python

        compressor = CliCompressor('gz', 'gzip')

        with compressor.compress('myfile.txt.gz') as f:
            f.write(b'foo')

        with compressor.decompress('myfile.txt.gz') as f:
            assert f.read() == b'foo'

    Args:
        extension (str): The suffix used to identify files that the compressor supports (without a leading `.`).
        cmd (str, optional): The name of the CLI tool that this compressor uses. Defaults to `None`, in which case
            it is assumed that the tool name is the same as the extension.
    """

    def __init__(self, extension: str, cmd: Optional[str] = None) -> None:
        self.extension = extension
        self.cmd = cmd if cmd is not None else extension

    def __repr__(self) -> str:
        return f'CliCompressor({self.extension!r}, {self.cmd!r})'

    @property
    def exists(self) -> bool:
        """Whether the CLI tool used by this compressor can be found."""
        return shutil.which(self.cmd) is not None

    def check_exists(self) -> None:
        """Raise :class:`CompressorNotFound` if the CLI tool cannot be found on the ``PATH``."""
        if not self.exists:
            raise CompressorNotFound(f'Could not find command "{self.cmd}" in the PATH.')

    def _compress_cmd(self) -> List[str]:
        # Run with stdin piped from the caller and stdout redirected to the output file (see ``compress``).
        return [self.cmd]

    @contextmanager
    def compress(self, out_filename: str) -> Iterator[IO[bytes]]:
        """Compress some data, saving to the given file.

        Yields a binary, write-only stream connected to the compressor's stdin. Everything written to it
        is compressed into ``out_filename``. The stream is closed and the child process is waited on when
        the ``with`` block exits, including when the block raises. In that case the caller's exception
        propagates unchanged.

        Args:
            out_filename (str): Path of the compressed file to write. It is opened with mode ``'wb'``,
                so existing content is truncated.

        Raises:
            CompressorNotFound: If the CLI tool is not on the ``PATH``.
            IOError: If the compressor exits with a non-zero return code after a successful ``with`` block.
        """
        self.check_exists()
        with open(out_filename, 'wb') as f:
            proc = subprocess.Popen(
                self._compress_cmd(),
                stdin=subprocess.PIPE,
                stdout=f,
            )
            assert proc.stdin is not None
            try:
                yield proc.stdin
            finally:
                # Always send EOF and reap the child, even if the caller's block raised. Otherwise the
                # pipe FD and a zombie process would leak. The inner finally guarantees ``wait()`` runs
                # even if flushing/closing the pipe fails (e.g. the tool already died).
                try:
                    proc.stdin.close()
                finally:
                    returncode = proc.wait()
            if returncode != 0:
                raise IOError(f'failed to compress to "{out_filename}" using {self!r} (return code {returncode})')

    def _decompress_cmd(self, filename: str) -> List[str]:
        # ``-dc``: decompress, writing to stdout. NOTE(review): assumes every tool in use accepts
        # gzip-style ``-d``/``-c`` flags; confirm when adding a new compressor.
        return [self.cmd, '-dc', filename]

    @contextmanager
    def decompress(self, in_filename: str) -> Iterator[IO[bytes]]:
        """Decompress the content of the given file, providing the output as a file-like object.

        Yields a binary, read-only stream connected to the decompressor's stdout. Read it to EOF before
        leaving the ``with`` block. On a normal exit the child process is waited on, and a child still
        blocked writing unread output would not exit. If the ``with`` block raises, the stream is closed
        first so the child cannot block, the child is reaped, and the caller's exception propagates.

        Args:
            in_filename (str): Path of the compressed file to read.

        Raises:
            CompressorNotFound: If the CLI tool is not on the ``PATH``.
            IOError: If the decompressor exits with a non-zero return code after a successful ``with`` block.
        """
        self.check_exists()
        proc = subprocess.Popen(
            self._decompress_cmd(in_filename),
            stdout=subprocess.PIPE,
        )
        assert proc.stdout is not None
        try:
            yield proc.stdout
        except BaseException:
            # The caller's block failed (or the context was abandoned). Closing our end of the pipe makes
            # any further writes by the child fail, so ``wait()`` cannot hang. Then reap it and re-raise.
            proc.stdout.close()
            proc.wait()
            raise
        returncode = proc.wait()
        # Release the pipe FD rather than leaving it to garbage collection.
        proc.stdout.close()
        if returncode != 0:
            raise IOError(f'failed to decompress "{in_filename}" using {self!r} (return code {returncode})')
def get_compressor(filename: str) -> CliCompressor:
    """Look up the compressor able to handle ``filename``.

    Args:
        filename (str): Name of a directly compressed pt file, e.g. ``'checkpoint.pt.gz'``.

    Returns:
        CliCompressor: The first entry of ``KNOWN_COMPRESSORS`` whose ``extension`` equals the text after
        the last ``.`` in ``filename``. Whether its CLI tool is installed is not checked here.

    Raises:
        ValueError: If ``is_compressed_pt(filename)`` is false.
        CompressorNotFound: If no known compressor handles the file's extension.
    """
    if not is_compressed_pt(filename):
        raise ValueError(f'The given filename does not correspond to a compressed file: "{filename}".')
    suffix = filename.rsplit('.', 1)[-1]
    found = next((candidate for candidate in KNOWN_COMPRESSORS if candidate.extension == suffix), None)
    if found is None:
        raise CompressorNotFound(f'Could not find compressor for "{filename}".')
    return found
# Compressors searched in order by ``get_compressor``: (file extension, CLI tool) pairs.
# A tool of ``None`` means the tool is named after the extension (see ``CliCompressor.__init__``).
KNOWN_COMPRESSORS = [
    CliCompressor(ext, tool) for ext, tool in (
        ('bz2', 'bzip2'),
        ('gz', 'gzip'),
        ('lz4', None),
        ('lzma', None),
        ('lzo', 'lzop'),
        ('xz', None),
        ('zst', 'zstd'),
    )
]