# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""List of Compression and Decompression algorithms."""
import bz2
import gzip
from abc import ABC, abstractmethod
from typing import Iterator, Optional
import brotli
import snappy
import zstd
from typing_extensions import Self
# Names exported by a star-import of this module.
__all__ = [
    'compress',
    'decompress',
    'get_compression_extension',
    'get_compressions',
    'is_compression',
]
class Compression(ABC):
"""A compression algorithm family."""
extension: str = '' # Filename extension.
@classmethod
def each(cls) -> Iterator[tuple[str, Self]]:
"""Get each instance of this compression algorithm family.
Returns:
Iterator[Tuple[str, Self]]: Each level.
"""
yield cls.extension, cls()
@abstractmethod
def compress(self, data: bytes) -> bytes:
"""Compress arbitrary data.
Args:
data (bytes): Uncompressed data.
Returns:
bytes: Compressed data.
"""
raise NotImplementedError
@abstractmethod
def decompress(self, data: bytes) -> bytes:
"""Decompress data compressed by this algorithm.
Args:
data (bytes): Compressed data.
Returns:
bytes: Decompressed data.
"""
raise NotImplementedError
class LevelledCompression(Compression):
    """A compression family whose instances are parameterised by a level.

    The valid levels are listed in ``levels``. The constructor defined here always raises
    ``NotImplementedError``; each concrete family supplies its own.

    Args:
        level (int, optional): Compression level. Defaults to ``None``.
    """

    levels: list = []  # Compression levels accepted by this family.

    def __init__(self, level: Optional[int] = None) -> None:
        raise NotImplementedError

    @classmethod
    def each(cls) -> Iterator[tuple[str, Self]]:
        """Enumerate the default instance, then one instance per level.

        Returns:
            Iterator[Tuple[str, Self]]: Pairs of name and instance. The first is keyed by
                ``extension`` alone; the rest are keyed by ``extension:level``.
        """
        default = cls()
        yield cls.extension, default
        yield from ((f'{cls.extension}:{level}', cls(level)) for level in cls.levels)
class Brotli(LevelledCompression):
    """Brotli compression.

    Args:
        level (int): Brotli quality, one of ``levels`` (0 through 11). Defaults to ``11``.
    """

    extension = 'br'
    levels = [*range(12)]

    def __init__(self, level: int = 11) -> None:
        assert level in self.levels
        self.level = level

    def compress(self, data: bytes) -> bytes:
        """Compress data with Brotli, passing this instance's level as the quality.

        Args:
            data (bytes): Uncompressed data.

        Returns:
            bytes: Compressed data.
        """
        quality = self.level
        return brotli.compress(data, quality=quality)

    def decompress(self, data: bytes) -> bytes:
        """Decompress Brotli data.

        Args:
            data (bytes): Compressed data.

        Returns:
            bytes: Decompressed data.
        """
        return brotli.decompress(data)
class Bzip2(LevelledCompression):
    """Bzip2 compression.

    Args:
        level (int): Compression level, one of ``levels`` (1 through 9). Defaults to ``9``.
    """

    extension = 'bz2'
    levels = [*range(1, 10)]

    def __init__(self, level: int = 9) -> None:
        assert level in self.levels
        self.level = level

    def compress(self, data: bytes) -> bytes:
        """Compress data with bzip2 at this instance's level.

        Args:
            data (bytes): Uncompressed data.

        Returns:
            bytes: Compressed data.
        """
        return bz2.compress(data, compresslevel=self.level)

    def decompress(self, data: bytes) -> bytes:
        """Decompress bzip2 data.

        Args:
            data (bytes): Compressed data.

        Returns:
            bytes: Decompressed data.
        """
        return bz2.decompress(data)
class Gzip(LevelledCompression):
    """Gzip compression.

    Args:
        level (int): Compression level, one of ``levels`` (0 through 9). Defaults to ``9``.
    """

    extension = 'gz'
    levels = [*range(10)]

    def __init__(self, level: int = 9) -> None:
        assert level in self.levels
        self.level = level

    def compress(self, data: bytes) -> bytes:
        """Compress data with gzip at this instance's level.

        Args:
            data (bytes): Uncompressed data.

        Returns:
            bytes: Compressed data.
        """
        return gzip.compress(data, compresslevel=self.level)

    def decompress(self, data: bytes) -> bytes:
        """Decompress gzip data.

        Args:
            data (bytes): Compressed data.

        Returns:
            bytes: Decompressed data.
        """
        return gzip.decompress(data)
class Snappy(Compression):
    """Snappy compression (a single variant with no levels)."""

    extension = 'snappy'

    def compress(self, data: bytes) -> bytes:
        """Compress data with Snappy.

        Args:
            data (bytes): Uncompressed data.

        Returns:
            bytes: Compressed data.
        """
        compressed = snappy.compress(data)
        return compressed

    def decompress(self, data: bytes) -> bytes:
        """Decompress Snappy data.

        Args:
            data (bytes): Compressed data.

        Returns:
            bytes: Decompressed data.
        """
        restored = snappy.decompress(data)
        return restored
class Zstandard(LevelledCompression):
    """Zstandard compression.

    Args:
        level (int): Compression level, one of ``levels`` (1 through 22). Defaults to ``3``.
    """

    extension = 'zstd'
    levels = [*range(1, 23)]

    def __init__(self, level: int = 3) -> None:
        assert level in self.levels
        self.level = level

    def compress(self, data: bytes) -> bytes:
        """Compress data with Zstandard at this instance's level.

        Args:
            data (bytes): Uncompressed data.

        Returns:
            bytes: Compressed data.
        """
        level = self.level
        return zstd.compress(data, level)

    def decompress(self, data: bytes) -> bytes:
        """Decompress Zstandard data.

        Args:
            data (bytes): Compressed data.

        Returns:
            bytes: Decompressed data.
        """
        return zstd.decompress(data)
# Registry of compression families: filename extension -> implementing class.
_families: dict[str, type[Compression]] = {
    'br': Brotli,
    'bz2': Bzip2,
    'gz': Gzip,
    'snappy': Snappy,
    'zstd': Zstandard,
}
def _collect(families: dict[str, type[Compression]]) -> dict[str, Compression]:
    """Flatten every instance of every compression family into a single mapping.

    Args:
        families (Dict[str, Type[Compression]]): Mapping of extension to class.

    Returns:
        Dict[str, Compression]: Mapping of extension:level to instance.
    """
    flattened: dict[str, Compression] = {}
    for family in families.values():
        # ``each()`` yields (name, instance) pairs, which ``dict.update`` consumes directly.
        flattened.update(family.each())
    return flattened
# Every supported algorithm, collected once when the module is imported
# (name, i.e. ``extension`` or ``extension:level`` -> instance).
_algorithms: dict[str, Compression] = _collect(_families)
def get_compressions() -> set[str]:
    """List supported compression algorithms.

    Returns:
        Set[str]: Names of all registered compression algorithms.
    """
    # Iterating the registry yields its keys; the caller gets a fresh set each time.
    return set(_algorithms)
def is_compression(algo: Optional[str]) -> bool:
    """Get whether this compression algorithm is supported.

    Args:
        algo (str, optional): Compression. ``None`` is never a registered name, so it
            yields ``False``.

    Returns:
        bool: Whether supported.
    """
    return algo in _algorithms
def get_compression_extension(algo: str) -> str:
    """Get compressed filename extension.

    Args:
        algo (str): Compression.

    Returns:
        str: Filename extension.

    Raises:
        ValueError: If ``algo`` is not a supported compression algorithm.
    """
    # Single lookup: registry values are instances, so ``None`` means "not registered".
    obj = _algorithms.get(algo)
    if obj is None:
        raise ValueError(f'{algo} is not a supported compression algorithm.')
    return obj.extension
def compress(algo: Optional[str], data: bytes) -> bytes:
    """Compress arbitrary data.

    Args:
        algo (str, optional): Compression. If ``None``, the data is returned unchanged.
        data (bytes): Uncompressed data.

    Returns:
        bytes: Compressed data.

    Raises:
        ValueError: If ``algo`` is not ``None`` and not a supported compression algorithm.
    """
    if algo is None:
        return data
    # Single lookup: registry values are instances, so ``None`` means "not registered".
    obj = _algorithms.get(algo)
    if obj is None:
        raise ValueError(f'{algo} is not a supported compression algorithm.')
    return obj.compress(data)
def decompress(algo: Optional[str], data: bytes) -> bytes:
    """Decompress data compressed by this algorithm.

    Args:
        algo (str, optional): Compression. If ``None``, the data is returned unchanged.
        data (bytes): Compressed data.

    Returns:
        bytes: Decompressed data.

    Raises:
        ValueError: If ``algo`` is not ``None`` and not a supported compression algorithm.
    """
    if algo is None:
        return data
    # Single lookup: registry values are instances, so ``None`` means "not registered".
    obj = _algorithms.get(algo)
    if obj is None:
        raise ValueError(f'{algo} is not a supported compression algorithm.')
    return obj.decompress(data)