Source code for mcli.utils.utils_config

"""Utils for modifying MCLI Configs"""
import copy
import logging
import warnings
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

import yaml
from typing_extensions import TypedDict

from mcli.api.exceptions import ValidationError
from mcli.utils.utils_logging import WARN, str_presenter
from mcli.utils.utils_string_functions import camel_case_to_snake_case, snake_case_to_camel_case
from mcli.utils.utils_yaml import load_yaml

logger = logging.getLogger(__name__)


def strip_nones(d: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dictionary holding only the entries of ``d`` whose value is not ``None``.

    Only ``None`` is dropped: other falsy values (``0``, ``''``, ``False``,
    empty containers) are kept. The input dictionary is not modified.

    Args:
        d (Dict[str, Any]): Dictionary to filter

    Returns:
        Dict[str, Any]: A new dictionary without the ``None``-valued keys
    """
    kept: Dict[str, Any] = {}
    for name, item in d.items():
        if item is None:
            continue
        kept[name] = item
    return kept


class SchedulingConfig(TypedDict, total=False):
    """Typed dictionary for nested scheduling configurations"""
    priority: Optional[str]
    resumable: Optional[bool]  # TODO: deprecate resumable
    preemptible: Optional[bool]
    retryOnSystemFailure: Optional[bool]
    max_retries: Optional[int]
    retry_on_system_failure: Optional[bool]
    max_duration: Optional[float]


class ComputeConfig(TypedDict, total=False):
    """Typed dictionary for nested compute requests"""
    cluster: Optional[str]
    instance: Optional[str]
    nodes: Optional[int]
    node_names: Optional[List[str]]
    gpu_type: Optional[str]
    gpus: Optional[int]
    cpus: Optional[int]


class MLflowConfig(TypedDict, total=False):
    """Typed dictionary for nested MLflow configs"""
    tracking_uri: Optional[str]
    experiment_path: str
    model_registry_path: Optional[str]


class WandbConfig(TypedDict, total=True):
    """Typed dictionary for nested W&B configs"""
    project: str
    entity: str


class ExperimentTrackerConfig(TypedDict, total=False):
    """Typed dictionary for nested experiment tracker configs"""
    mlflow: Optional[MLflowConfig]
    wandb: Optional[WandbConfig]


@dataclass
class BaseSubmissionConfig():
    """Base class for config objects

    Provides construction from a dictionary or YAML file and a YAML string
    representation built from the dataclass fields.
    """

    # Field names that ``__str__`` always emits, even when their value is empty
    _required_display_properties = set()

    @classmethod
    def empty(cls):
        """Create a config using only the dataclass defaults."""
        return cls()

    @classmethod
    def from_file(cls, path: Union[str, Path]):
        """Load the config from the provided YAML file.

        Keys in the file that are not fields of the config trigger the
        unused-field warning of :meth:`from_dict`.

        Args:
            path (Union[str, Path]): Path to YAML file

        Returns:
            BaseSubmissionConfig: The BaseSubmissionConfig object specified in the YAML file
        """
        config = load_yaml(path)
        return cls.from_dict(config, show_unused_warning=True)

    def to_file(self, path: Union[str, Path]):
        """Save the config to the provided YAML file.

        The file content is the string representation of the config (see ``__str__``).

        Args:
            path (Union[str, Path]): Path to YAML file
        """
        with open(path, 'w', encoding='utf8') as f:
            f.write(str(self))

    @classmethod
    def from_dict(cls, dict_to_use: Dict[str, Any], show_unused_warning: bool = False):
        """Load the config from the provided dictionary.

        Only keys that match a dataclass field of ``cls`` are passed to the
        constructor; every other key is ignored.

        Args:
            dict_to_use (Dict[str, Any]): The dictionary to populate the BaseSubmissionConfig with
            show_unused_warning (bool): If True, report the ignored keys (default: False)

        Returns:
            BaseSubmissionConfig: The BaseSubmissionConfig object specified in the dictionary
        """
        field_names = {f.name for f in fields(cls)}

        unused_keys = []
        constructor = {}
        for key, value in dict_to_use.items():
            if key in field_names:
                constructor[key] = value
            else:
                unused_keys.append(key)

        if unused_keys and show_unused_warning:
            if {'model', 'train_data_path'}.issubset(unused_keys):
                # These two keys together get a dedicated hint instead of the generic warning
                logger.info(
                    f'{WARN} You specified a model, train_data_path, and cluster. Did you mean to use `mcli finetune`?')
            else:
                # pylint: disable=line-too-long
                warnings.warn(
                    f'! Encountered unknown fields {", ".join(unused_keys)} which were not used in creating the request'
                )
        return cls(**constructor)

    def __str__(self) -> str:
        """Render the config as YAML, leaving out empty optional properties."""
        filtered_dict = {}
        for k, v in asdict(self).items():
            # skip nested and direct empty values for optional properties
            if k not in self._required_display_properties:
                # A dict whose values are all falsy counts as empty even though the dict itself is truthy
                if isinstance(v, dict) and not any(v.values()):
                    continue
                if not v:
                    continue
            filtered_dict[k] = v

        # to use with safe_dump:
        yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
        return yaml.safe_dump(filtered_dict, default_flow_style=False, sort_keys=False).strip()


T = TypeVar('T')  # pylint: disable=invalid-name
U = TypeVar('U')


class Translation(ABC, Generic[T, U]):
    """ABC for MAPI/MCLI translations

    ``T`` is the MCLI-side type and ``U`` the MAPI-side type.
    """

    @classmethod
    @abstractmethod
    def to_mapi(cls, value: T) -> U:
        ...

    @classmethod
    @abstractmethod
    def from_mapi(cls, value: U) -> T:
        ...


class EnvVarTranslation(Translation[Dict[str, str], List[Dict[str, str]]]):
    """Translate environment variable configs"""

    MAPI_KEY = 'envKey'
    MAPI_VALUE = 'envValue'

    @classmethod
    def to_mapi(cls, value: Dict[str, str]) -> List[Dict[str, str]]:
        """Convert a ``{name: value}`` mapping into a list of ``{envKey, envValue}`` dictionaries."""
        return [{cls.MAPI_KEY: key, cls.MAPI_VALUE: val} for key, val in value.items()]

    @classmethod
    def from_mapi(cls, value: List[Dict[str, str]]) -> Dict[str, str]:
        """Convert a list of ``{envKey, envValue}`` dictionaries into a ``{name: value}`` mapping.

        Entries missing either key are logged as a warning and skipped.
        """
        env_vars = {}
        for env_var in value:
            try:
                key = env_var[cls.MAPI_KEY]
                val = env_var[cls.MAPI_VALUE]
            except KeyError:
                logger.warning(f'Received incompatible environment variable: {env_var}')
                continue
            env_vars[key] = val
        return env_vars


class IntegrationTranslation(Translation[List[Dict[str, Any]], List[Dict[str, Any]]]):
    """Translate integration configs"""

    MAPI_TYPE = 'type'
    MAPI_PARAMS = 'params'

    MCLI_TYPE = 'integration_type'

    @classmethod
    def to_mapi(cls, value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert MCLI integrations into ``{type, params}`` dictionaries with camel-case parameter keys.

        Args:
            value (List[Dict[str, Any]]): Integrations, each carrying its type under
                ``integration_type`` or ``type``

        Returns:
            List[Dict[str, Any]]: One ``{type, params}`` dictionary per integration

        Raises:
            ValidationError: If ``value`` is not a list of dictionaries or an
                integration has no type key
        """
        # Work on a copy: the type key is popped from each integration below
        value = copy.deepcopy(value)
        integrations_list = []

        if not isinstance(value, list) or not all(isinstance(item, dict) for item in value):
            raise ValidationError(f'Integrations input must be a list of dictionary, received: {value}')

        for integration in value:
            if cls.MCLI_TYPE not in integration and cls.MAPI_TYPE not in integration:
                raise ValidationError(f'Integration missing required key: {cls.MCLI_TYPE}, received {integration}')
            elif cls.MAPI_TYPE in integration:
                integration_type = integration.pop(cls.MAPI_TYPE)
            else:
                integration_type = integration.pop(cls.MCLI_TYPE)

            translated_integration = {}
            for param, val in integration.items():
                # Translate keys to camel case for MAPI parameters
                translated_key = snake_case_to_camel_case(param)
                translated_integration[translated_key] = val

            integrations_dict = {cls.MAPI_TYPE: integration_type, cls.MAPI_PARAMS: translated_integration}
            integrations_list.append(integrations_dict)
        return integrations_list

    @classmethod
    def from_mapi(cls, value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert ``{type, params}`` dictionaries into flat MCLI integrations with snake-case keys.

        A missing or ``None`` ``params`` entry yields an integration holding
        only its ``integration_type``.
        """
        integrations_list = []
        for integration in value:
            translated_integration = {cls.MCLI_TYPE: integration[cls.MAPI_TYPE]}
            params = integration.get(cls.MAPI_PARAMS) or {}
            for param, val in params.items():
                # Translate keys to snake case for MCLI parameters
                translated_key = camel_case_to_snake_case(param)
                translated_integration[translated_key] = val
            integrations_list.append(translated_integration)
        return integrations_list


class ComputeTranslation(Translation[ComputeConfig, Dict[str, Any]]):
    """Translate compute configs to and from MAPI"""

    # MAPI key -> MCLI key; keys not listed here are passed through unchanged
    translations = {
        "gpuType": "gpu_type",
        "nodeNames": "node_names",
    }

    @classmethod
    def from_mapi(cls, value: Dict[str, Any]) -> ComputeConfig:
        """Rename MAPI keys to their MCLI names and wrap the result in a ComputeConfig."""
        extracted = ComputeConfig(**{cls.translations.get(k, k): v for k, v in value.items()})
        return extracted

    @classmethod
    def to_mapi(cls, value: ComputeConfig) -> Dict[str, Any]:
        """Rename MCLI keys to their MAPI names, dropping entries whose value is None."""
        inv = {v: k for k, v in cls.translations.items()}
        processed = {inv.get(k, k): v for k, v in value.items() if v is not None}
        return processed


class SchedulingTranslation(Translation[SchedulingConfig, Dict[str, Any]]):
    """Translate scheduling configs to and from MAPI"""

    # MAPI key -> MCLI key; keys not listed here are passed through unchanged
    translations = {
        "maxRetries": "max_retries",
        "retryOnSystemFailure": "retry_on_system_failure",
        "maxDuration": "max_duration"
    }

    @classmethod
    def from_mapi(cls, value: Dict[str, Any]) -> SchedulingConfig:
        """Rename MAPI keys to their MCLI names, leaving out the ``priorityInt`` key."""
        extracted = SchedulingConfig(**{cls.translations.get(k, k): v for k, v in value.items() if k != "priorityInt"})
        return extracted

    @classmethod
    def to_mapi(cls, value: SchedulingConfig) -> Dict[str, Any]:
        """Rename MCLI keys to their MAPI names, dropping entries whose value is None."""
        inv = {v: k for k, v in cls.translations.items()}
        processed = {inv.get(k, k): v for k, v in value.items() if v is not None}
        return processed


class DependentDeploymentConfig(Translation, Generic[T]):
    """Basic translation for dependent deployment configs"""

    @classmethod
    def to_mapi(cls, value: Dict[str, Any]) -> Dict[str, Any]:
        """Convert top-level keys and the keys of directly nested dictionaries to camel case.

        ``env_variables`` is converted with :class:`EnvVarTranslation` instead.
        """
        translated_config = {}
        for key, val in value.items():
            if key == 'env_variables':
                val = EnvVarTranslation.to_mapi(val)
            elif isinstance(val, dict):
                # This purposefully goes 2 levels deep and not further
                # due to how the inference server expects the config
                new_dict = {}
                for k, v in val.items():
                    new_dict[snake_case_to_camel_case(k)] = v
                val = new_dict
            mapi_key = snake_case_to_camel_case(key)
            translated_config[mapi_key] = val
        return translated_config

    @classmethod
    def from_mapi(cls, value: Dict[str, Any]) -> Dict[str, Any]:
        """Convert top-level keys and the keys of directly nested dictionaries to snake case.

        ``envVariables`` is converted with :class:`EnvVarTranslation` instead.
        """
        translated_config = {}
        for key, val in value.items():
            if key == 'envVariables':
                val = EnvVarTranslation.from_mapi(val)
            elif isinstance(val, dict):
                new_dict = {}
                for k, v in val.items():
                    new_dict[camel_case_to_snake_case(k)] = v
                val = new_dict
            mcli_key = camel_case_to_snake_case(key)
            translated_config[mcli_key] = val
        return translated_config


class ExperimentTrackerTranslation(Translation[ExperimentTrackerConfig, Dict[str, Any]]):
    """Translate experiment tracker configs to and from MAPI"""

    # Per tracker: MAPI key -> MCLI key; unknown trackers and keys are passed through unchanged
    translations = {
        "mlflow": {
            'trackingUri': 'tracking_uri',
            'experimentPath': 'experiment_path',
            'modelRegistryPath': 'model_registry_path'
        },
        "wandb": {
            'project': 'project',
            'entity': 'entity'
        }
    }

    @classmethod
    def from_mapi(cls, value: Dict[str, Any]) -> ExperimentTrackerConfig:
        """Rename the keys of every tracker config to their MCLI names.

        Trackers whose config is ``None`` are left out of the result.
        """
        extracted = ExperimentTrackerConfig()
        for tracker_name, tracker_config in value.items():
            if tracker_config is None:
                continue
            key_map = cls.translations.get(tracker_name, {})
            extracted[tracker_name] = {key_map.get(k, k): v for k, v in dict(tracker_config).items()}
        return extracted

    @classmethod
    def to_mapi(cls, value: ExperimentTrackerConfig) -> Dict[str, Any]:
        """Rename the keys of every tracker config to their MAPI names.

        Trackers whose config is ``None`` and entries whose value is ``None``
        are left out of the result.

        Raises:
            ValidationError: If a tracker config is neither ``None`` nor a dictionary
        """
        out = {}
        for tracker_name, tracker_config in value.items():
            if tracker_config is None:
                continue
            if not isinstance(tracker_config, dict):
                raise ValidationError(
                    f'Experiment tracker config for {tracker_name} must be a dictionary, received: {tracker_config}')
            inv = {v: k for k, v in cls.translations.get(tracker_name, {}).items()}
            out[tracker_name] = {inv.get(k, k): v for k, v in tracker_config.items() if v is not None}
        return out