Source code for composer.utils.string_enum

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Base class for Enums containing string values."""
from __future__ import annotations

import textwrap
import warnings
from enum import Enum


class StringEnum(Enum):
    """Base class for Enums containing string values.

    This class enforces that all keys are uppercase and all values are lowercase. It also
    offers the following convenience features:

    *   ``StringEnum(value)`` will perform a case-insensitive match on both the keys and value,
        and is a no-op if given an existing instance of the class.

        .. testsetup::

            import warnings

            warnings.filterwarnings(action="ignore", message="Detected comparision between a string")

        .. doctest::

            >>> from composer.utils import StringEnum
            >>> class MyStringEnum(StringEnum):
            ...     KEY = "value"
            >>> MyStringEnum("KeY")  # case-insensitive match on the key
            <MyStringEnum.KEY: 'value'>
            >>> MyStringEnum("VaLuE")  # case-insensitive match on the value
            <MyStringEnum.KEY: 'value'>
            >>> MyStringEnum(MyStringEnum.KEY)  # no-op if given an existing instance
            <MyStringEnum.KEY: 'value'>

        .. testcleanup::

            warnings.resetwarnings()

    *   Equality checks support case-insensitive comparisons against strings:

        .. testsetup::

            import warnings

            warnings.filterwarnings(action="ignore", message="Detected comparision between a string")

        .. doctest::

            >>> from composer.utils import StringEnum
            >>> class MyStringEnum(StringEnum):
            ...     KEY = "value"
            >>> MyStringEnum.KEY == "KeY"  # case-insensitive match on the key
            True
            >>> MyStringEnum.KEY == "VaLuE"  # case-insensitive match on the value
            True
            >>> MyStringEnum.KEY == "something else"
            False

        .. testcleanup::

            warnings.resetwarnings()
    """

    # Defining ``__eq__`` below would otherwise reset ``__hash__`` to ``None``;
    # explicitly keep the hash implementation inherited from ``Enum``.
    __hash__ = Enum.__hash__  # pyright: ignore[reportGeneralTypeIssues]

    def __eq__(self, other: object) -> bool:
        """Compare against another object, accepting strings case-insensitively.

        Args:
            other: The object to compare with. A ``str`` is first converted to a
                member of this enum class (emitting a ``UserWarning``).

        Returns:
            The result of the parent class's ``__eq__`` on the (possibly converted)
            operand, or ``NotImplemented`` if ``other`` is a string that does not
            correspond to any member.
        """
        if isinstance(other, str):
            cls_name = self.__class__.__name__
            # NOTE: the spelling "comparision" is kept as-is on purpose; the
            # ``warnings.filterwarnings(message=...)`` calls in the class docstring
            # match on the start of this exact message.
            warnings.warn(
                f"Detected comparision between a string and {cls_name}. Please use {cls_name}('{other}') "
                f'to convert both types to {cls_name} before comparing.',
                category=UserWarning)
            try:
                o_enum = type(self)(other)
            except ValueError:  # `other` is not a valid enum option
                return NotImplemented
            return super().__eq__(o_enum)
        return super().__eq__(other)

    def __init__(self, *args: object) -> None:
        """Validate the naming convention of a newly created member.

        Args:
            *args: Accepted but not used by this method.

        Raises:
            ValueError: If the member's name is not all uppercase, or its value is
                not all lowercase.
        """
        if self.name.upper() != self.name:
            raise ValueError(
                textwrap.dedent(f"""\
                    {self.__class__.__name__}.{self.name} is invalid.
                    All keys in {self.__class__.__name__} must be uppercase.
                    To fix, rename to '{self.name.upper()}'."""))
        if self.value.lower() != self.value:
            # A stray double quote that used to trail "must be lowercase." has been
            # removed from this message.
            raise ValueError(
                textwrap.dedent(f"""\
                    The value for {self.__class__.__name__}.{self.name}={self.value} is invalid.
                    All values in {self.__class__.__name__} must be lowercase.
                    To fix, rename to '{self.value.lower()}'."""))

    @classmethod
    def _missing_(cls, value: object) -> StringEnum:
        """Resolve a lookup that did not match any member value exactly.

        Args:
            value: The object passed to ``cls(value)``.

        Returns:
            ``value`` itself if it is already an instance of ``cls``; otherwise the
            member whose name equals ``value.upper()``, or whose value equals
            ``value.lower()``.

        Raises:
            ValueError: If ``value`` is a string that matches no member.
            TypeError: If ``value`` is neither a string nor an instance of ``cls``.
        """
        # Override _missing_ so both lowercase and uppercase names are supported,
        # as well as passing an instance through
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            try:
                return cls[value.upper()]
            except KeyError:
                # Retry by value with the lowercased string. When ``value`` is
                # already lowercase there is nothing left to try.
                if value.lower() != value:
                    return cls(value.lower())
                raise ValueError(f'Value {value} not found in {cls.__name__}')
        raise TypeError(f'Unable to convert value({value}) of type {type(value)} into {cls.__name__}')