"""Global Singleton Config Store"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
import ruamel.yaml
import yaml
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from mcli.utils.utils_yaml import StringDumpYAML
# Logger for this module.
logger = logging.getLogger(__name__)
# Turn off urllib3's connection-pool logger entirely (presumably to keep its
# per-connection chatter out of mcli output).
logging.getLogger('urllib3.connectionpool').disabled = True
def env_path_override_config(config_value: str):
    """Replace the module global named ``config_value`` with a ``Path`` taken from the environment.

    The environment variable looked up has the same name as the global. When it
    is unset, the global is left untouched.
    """
    raw_value = os.environ.get(config_value)
    if raw_value is None:
        return
    globals()[config_value] = Path(raw_value)
def env_str_override_config(config_value: str):
    """Replace the module global named ``config_value`` with the raw environment string.

    The environment variable looked up has the same name as the global. When it
    is unset, the global is left untouched.
    """
    try:
        env_value = os.environ[config_value]
    except KeyError:
        return
    globals()[config_value] = env_value
# Directory for mcli's local state. If the MCLI_CONFIG_DIR environment variable
# is set, it replaces this default (wrapped in a Path).
MCLI_CONFIG_DIR: Path = Path(os.path.expanduser('~/.mosaic'))
env_path_override_config('MCLI_CONFIG_DIR')
# MAPI (GraphQL) endpoints; MCLIMode.endpoint picks one per mode.
MOSAICML_API_ENDPOINT: str = 'https://api.mosaicml.com/graphql'
MOSAICML_API_ENDPOINT_STAGING: str = 'https://staging.api.mosaicml.com/graphql'
MOSAICML_API_ENDPOINT_DEV: str = 'https://dev.api.mosaicml.com/graphql'
MOSAICML_API_ENDPOINT_LOCAL: str = 'http://localhost:3001/graphql'
MOSAICML_API_ENDPOINT_ENV: str = 'MOSAICML_API_ENDPOINT'
DATABRICKS_API_ENDPOINT_STAGING: str = 'https://oregon.staging.cloud.databricks.com/api/2.0/genai-mapi/graphql'
# The *_ENV value is the name of the prod constant above. When the environment
# variable is set, this call therefore replaces the module-level
# MOSAICML_API_ENDPOINT with the environment value.
env_str_override_config(MOSAICML_API_ENDPOINT_ENV)
# MINT websocket shell endpoints; MCLIMode.mint_endpoint picks one per mode.
MOSAICML_MINT_ENDPOINT: str = 'wss://mint.mosaicml.com/v1/shell'
MOSAICML_MINT_ENDPOINT_STAGING: str = 'wss://staging.mint.mosaicml.com/v1/shell'
MOSAICML_MINT_ENDPOINT_DEV: str = 'wss://dev.mint.mosaicml.com/v1/shell'
MOSAICML_MINT_ENDPOINT_LOCAL: str = 'ws://localhost:3004/v1/shell'
MOSAICML_MINT_ENDPOINT_ENV: str = 'MOSAICML_MINT_ENDPOINT'
# Same pattern: may replace the module-level MOSAICML_MINT_ENDPOINT.
env_str_override_config(MOSAICML_MINT_ENDPOINT_ENV)
# Config file location. It is derived from MCLI_CONFIG_DIR *after* that
# directory's override has been applied, and can itself be overridden through
# the MCLI_CONFIG_PATH environment variable.
MCLI_CONFIG_PATH: Path = MCLI_CONFIG_DIR / 'mcli_config'
env_path_override_config('MCLI_CONFIG_PATH')
# NOTE(review): presumably the minimum number of days between upgrade checks;
# it is not read anywhere in this view.
UPDATE_CHECK_FREQUENCY_DAYS: float = 2
# Names the mode variable. If set, the override call below also creates a module
# global named MCLI_MODE. MCLIMode.from_env reads os.environ directly.
MCLI_MODE_ENV: str = 'MCLI_MODE'
env_str_override_config(MCLI_MODE_ENV)
# Names the timeout variable. get_timeout reads os.environ directly and parses
# the value with float().
MCLI_TIMEOUT_ENV: str = 'MCLI_TIMEOUT'
env_str_override_config(MCLI_TIMEOUT_ENV)
# MCLIConfig.disable_upgrade is True only when this variable is 'true'
# (case-insensitive).
MCLI_DISABLE_UPGRADE_CHECK_ENV: str = 'MCLI_DISABLE_UPGRADE_CHECK'
env_str_override_config(MCLI_DISABLE_UPGRADE_CHECK_ENV)
# Used for local dev and testing
# MCLIConfig.get_api_key lets this variable override the configured API key.
MOSAICML_API_KEY_ENV: str = 'MOSAICML_API_KEY'
# Names a file whose contents MCLIConfig.access_token returns.
MOSAICML_ACCESS_TOKEN_FILE_ENV: str = 'MOSAICML_ACCESS_TOKEN_FILE'
# When True, MCLIConfig exposes the stored user/org ids and update_entity adds
# them to MAPI query variables.
ADMIN_MODE: bool = False
def get_timeout(default_timeout: Optional[float] = None) -> Optional[float]:
    """Return the timeout configured through the environment.

    Args:
        default_timeout: Value returned when the variable named by
            ``MCLI_TIMEOUT_ENV`` is unset or empty.

    Returns:
        The variable's value parsed with ``float`` (a malformed value raises
        ``ValueError``), otherwise ``default_timeout``.
    """
    raw_timeout = os.environ.get(MCLI_TIMEOUT_ENV)
    return float(raw_timeout) if raw_timeout else default_timeout
class FeatureFlag(Enum):
    """Feature flags that gate mcli functionality.

    Each flag's string value is the key used when the flag is stored in a
    config's ``feature_flags`` mapping.
    """

    ALPHA_TESTER = 'ALPHA_TESTER'

    @staticmethod
    def get_external_features() -> Set[FeatureFlag]:
        """Return the flags that non-internal users may enable.

        No flag is currently exposed externally, so a fresh empty set is returned.
        """
        external_flags: Set[FeatureFlag] = set()
        return external_flags
class MCLIMode(Enum):
    """Enum for mcli user modes.

    The active mode comes from the environment variable named by
    ``MCLI_MODE_ENV`` (see :meth:`from_env`). It selects which MAPI and MINT
    endpoints mcli talks to.
    """

    PROD = 'PROD'
    DEV = 'DEV'
    LOCAL = 'LOCAL'
    STAGING = 'STAGING'
    DBX_AWS_STAGING = 'DBX_AWS_STAGING'

    def is_internal(self) -> bool:
        """True if this mode is an internal mode (every mode except PROD)."""
        internal_modes = {MCLIMode.DEV, MCLIMode.LOCAL, MCLIMode.STAGING, MCLIMode.DBX_AWS_STAGING}
        return self in internal_modes

    def available_feature_flags(self) -> List[FeatureFlag]:
        """Feature flags that users in this mode may enable."""
        if self.is_internal():
            # All features are available to internal users
            return list(FeatureFlag)
        return list(FeatureFlag.get_external_features())

    @classmethod
    def from_env(cls) -> MCLIMode:
        """Return the mode named in the environment, defaulting to PROD.

        The value is compared case-insensitively against the member values.
        An unset, empty, or unrecognised value yields :attr:`PROD`.
        """
        found_mode = os.environ.get(MCLI_MODE_ENV, None)
        if found_mode:
            found_mode = found_mode.upper()
            for mode in cls:
                if found_mode == mode.value:
                    return mode
        return cls.PROD

    @property
    def endpoint(self) -> str:
        """The MAPI endpoint value for the given environment."""
        # Module constants are read at call time so their env overrides apply.
        if self is MCLIMode.DEV:
            return MOSAICML_API_ENDPOINT_DEV
        elif self is MCLIMode.LOCAL:
            return MOSAICML_API_ENDPOINT_LOCAL
        elif self is MCLIMode.STAGING:
            return MOSAICML_API_ENDPOINT_STAGING
        elif self is MCLIMode.DBX_AWS_STAGING:
            return DATABRICKS_API_ENDPOINT_STAGING
        return MOSAICML_API_ENDPOINT

    @property
    def mint_endpoint(self) -> str:
        """The MINT endpoint value for the given environment."""
        if self is MCLIMode.DEV:
            return MOSAICML_MINT_ENDPOINT_DEV
        elif self is MCLIMode.LOCAL:
            return MOSAICML_MINT_ENDPOINT_LOCAL
        elif self is MCLIMode.STAGING:
            return MOSAICML_MINT_ENDPOINT_STAGING
        # NOTE(review): DBX_AWS_STAGING has no dedicated MINT endpoint and falls
        # through to the prod MINT endpoint here; confirm this is intended.
        return MOSAICML_MINT_ENDPOINT

    def is_alternate(self) -> bool:
        """True if the mode is a valid alternate mcloud environment.

        API keys for alternate environments are stored per mode (in
        ``MCLIConfig.mcloud_envs``) instead of the default PROD key slot.
        DBX_AWS_STAGING has its own endpoint, so it must be included. Otherwise
        setting its API key would overwrite the PROD key, and the PROD key would
        be sent to the staging endpoint.
        """
        alternate_env_modes = {MCLIMode.DEV, MCLIMode.LOCAL, MCLIMode.STAGING, MCLIMode.DBX_AWS_STAGING}
        return self in alternate_env_modes
@dataclass
class MCLIConfig:
    """Global Config Store persisted on local disk.

    :meth:`save_config` writes the config as YAML to ``MCLI_CONFIG_PATH`` and
    :meth:`load_config` reads it back. Several properties (``mcli_mode``,
    ``endpoint``, ``mint_endpoint``, ``api_key``, ``access_token``) consult
    environment variables on every access, so effective values may differ from
    what is stored on disk.
    """

    # API key used whenever the current mode is not an alternate mcloud environment
    MOSAICML_API_KEY: str = ''  # pylint: disable=invalid-name Global Stored within Singleton
    # Feature-flag value -> enabled; from_dict drops unknown or disabled flags
    feature_flags: Dict[str, bool] = field(default_factory=dict)
    last_update_check: datetime = field(default_factory=datetime.now)

    # MCloud environments w/ API keys, keyed by MCLIMode value
    # Most users will be in PROD, so this will likely only be touched internally
    mcloud_envs: Dict[str, str] = field(default_factory=dict)

    # Only surfaced through the user_id / organization_id properties in ADMIN_MODE
    _user_id: Optional[str] = None
    _organization_id: Optional[str] = None

    @property
    def user_id(self) -> Optional[str]:
        """The stored user id in ADMIN_MODE, otherwise None."""
        # User id is only relevant in admin mode. If using normal mcli, it should always
        # be blank and the user just needs to authenticate through their api key
        if ADMIN_MODE:
            return self._user_id
        return None

    @user_id.setter
    def user_id(self, value: Optional[str]):
        self._user_id = value

    @property
    def organization_id(self) -> Optional[str]:
        """The stored organization id in ADMIN_MODE, otherwise None."""
        if ADMIN_MODE:
            return self._organization_id
        return None

    @organization_id.setter
    def organization_id(self, value: Optional[str]):
        self._organization_id = value

    def update_entity(
        self,
        variables: Dict[str, Any],
        *,
        set_user: bool = True,
        set_org: bool = True,
    ):
        """Add admin-mode entity filters to MAPI query variables, in place.

        Does nothing outside ADMIN_MODE. Also does nothing when no requested id
        (user or organization) is actually stored.

        Args:
            variables: Query variables. ``variables['entity']`` is created if
                missing and receives ``userIds`` / ``organizationIds`` lists.
            set_user: Whether to add the stored user id.
            set_org: Whether to add the stored organization id.
        """
        if not ADMIN_MODE:
            return

        set_user &= (self.user_id is not None)
        set_org &= (self.organization_id is not None)
        if not set_user and not set_org:
            return

        entity = variables.setdefault('entity', {})
        if set_user:
            entity['userIds'] = [self.user_id]
        if set_org:
            entity['organizationIds'] = [self.organization_id]
        logger.info('Making mapi query with entity %s', entity)

    @classmethod
    def empty(cls) -> MCLIConfig:
        """A config with every field at its default value."""
        return cls()

    @property
    def internal(self) -> bool:
        """True if the current mode (read from the environment) is internal."""
        return self.mcli_mode.is_internal()

    @property
    def mcli_mode(self) -> MCLIMode:
        """The mcli mode, re-read from the environment on every access."""
        return MCLIMode.from_env()

    @property
    def disable_upgrade(self) -> bool:
        """True if MCLI_DISABLE_UPGRADE_CHECK_ENV is set to 'true' (case-insensitive)."""
        disable_env = os.environ.get(MCLI_DISABLE_UPGRADE_CHECK_ENV, 'false').lower()
        return disable_env == 'true'

    @property
    def endpoint(self) -> str:
        """The user's MAPI endpoint.

        A non-empty MOSAICML_API_ENDPOINT_ENV environment variable takes
        precedence over the current mode's endpoint.
        """
        env_endpoint = os.environ.get(MOSAICML_API_ENDPOINT_ENV, None)
        return env_endpoint or self.mcli_mode.endpoint

    @property
    def mint_endpoint(self) -> str:
        """The user's MINT endpoint.

        A non-empty MOSAICML_MINT_ENDPOINT_ENV environment variable takes
        precedence over the current mode's MINT endpoint.
        """
        env_endpoint = os.environ.get(MOSAICML_MINT_ENDPOINT_ENV, None)
        return env_endpoint or self.mcli_mode.mint_endpoint

    @property
    def api_key(self) -> str:
        """The user's configured MCloud API key (the environment may override it)."""
        return self.get_api_key(env_override=True)

    @api_key.setter
    def api_key(self, value: str):
        if self.mcli_mode.is_alternate():
            # If the user is using an alternative mcloud, set that API key
            self.mcloud_envs[self.mcli_mode.value] = value
        else:
            self.MOSAICML_API_KEY = value

    @property
    def access_token(self) -> str:
        """Contents of the file named by MOSAICML_ACCESS_TOKEN_FILE_ENV, or '' if unset.

        Surrounding whitespace is stripped. Token files commonly end with a
        newline, which is never part of a valid credential.
        """
        access_token_file = os.environ.get(MOSAICML_ACCESS_TOKEN_FILE_ENV, None)
        if not access_token_file:
            return ''
        with open(access_token_file, 'r', encoding='UTF-8') as f:
            return f.read().strip()

    def get_api_key(self, env_override: bool = True):
        """Get the user's current API key

        Args:
            env_override (bool, optional): If True, allow an environment variable to
                override the configured value, otherwise pull only from the user's config
                file. Defaults to True.

        Returns:
            str: The user's API key, if set, otherwise an empty string
        """
        api_key_env = os.environ.get(MOSAICML_API_KEY_ENV, None)
        if api_key_env is not None and env_override:
            return api_key_env
        elif self.mcli_mode.is_alternate():
            return self.mcloud_envs.get(self.mcli_mode.value, '')
        elif self.MOSAICML_API_KEY:
            return self.MOSAICML_API_KEY
        return ''

    def to_dict(self) -> Dict[str, Any]:
        """Converts the config to a dictionary

        ``last_update_check`` is always included. Every other field is included
        only when it is non-empty.

        Returns:
            Dict[str, Any]: The dictionary representation of the config
        """
        res: Dict[str, Any] = {
            'last_update_check': self.last_update_check,
        }

        # Only add configs if they are filled
        if self.MOSAICML_API_KEY:
            res['MOSAICML_API_KEY'] = self.MOSAICML_API_KEY
        if self.feature_flags:
            res['feature_flags'] = self.feature_flags
        if self.mcloud_envs:
            res['mcloud_envs'] = self.mcloud_envs
        if self._user_id:
            res['_user_id'] = self._user_id
        if self._organization_id:
            res['_organization_id'] = self._organization_id

        return res

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> MCLIConfig:
        """Build a config from a mapping, e.g. one loaded from the YAML config file.

        Unknown or falsy feature flags are dropped. ``last_update_check`` may be an
        ISO-format string or a datetime; any other value is replaced with now.
        Keys that are present but null are treated as missing.
        """
        # Remove any unknown or false feature flags
        known_feature_flags = {f.value for f in FeatureFlag}
        raw_flags = data.get('feature_flags') or {}
        feature_flags = {k: v for k, v in raw_flags.items() if k in known_feature_flags and v}

        raw_last_check = data.get('last_update_check')
        if isinstance(raw_last_check, str):
            last_update_check = datetime.fromisoformat(raw_last_check)
        elif isinstance(raw_last_check, datetime):
            last_update_check = raw_last_check
        else:
            last_update_check = datetime.now()

        return cls(
            MOSAICML_API_KEY=data.get('MOSAICML_API_KEY') or '',
            feature_flags=feature_flags,
            last_update_check=last_update_check,
            mcloud_envs=data.get('mcloud_envs') or {},
            _user_id=data.get('_user_id', None),
            _organization_id=data.get('_organization_id', None),
        )

    @classmethod
    def load_config(cls) -> MCLIConfig:
        """Loads the MCLIConfig from local disk

        Return:
            Returns the MCLIConfig, if not found, returns a new empty config
        """
        try:
            with open(MCLI_CONFIG_PATH, 'r', encoding='utf8') as f:
                data: Optional[Dict[str, Any]] = yaml.full_load(f)
        except FileNotFoundError:
            return cls.empty()
        # An empty config file parses to None; treat it like an empty mapping
        return cls.from_dict(data or {})

    def save_config(self) -> bool:
        """Saves the MCLIConfig to local disk

        Creates the parent directory of MCLI_CONFIG_PATH if needed.

        Return:
            Returns true if successful
        """
        logger.debug(f'Saving config to {MCLI_CONFIG_PATH}')
        data = self._get_formatted_dump()
        y = YAML()
        y.explicit_start = True  # type: ignore
        # Path.parent is '.' for a bare filename, whereas os.path.dirname returns ''
        # and os.makedirs('') raises FileNotFoundError.
        Path(MCLI_CONFIG_PATH).parent.mkdir(parents=True, exist_ok=True)
        with open(MCLI_CONFIG_PATH, 'w', encoding='utf8') as f:
            y.dump(data, f)
        return True

    def _get_formatted_dump(self) -> CommentedMap:
        """Gets the ruamel yaml formatted dump of the config

        Serializes to_dict() with PyYAML, then re-parses the text with ruamel's
        round-trip ('rt') loader. The result is suitable for dumping with ruamel.
        """
        raw_data = self.to_dict()
        y = ruamel.yaml.YAML(typ='rt', pure=True)
        data: CommentedMap = y.load(yaml.dump(raw_data))
        return data

    def feature_enabled(self, feature: FeatureFlag) -> bool:
        """Checks if the feature flag is enabled

        Args:
            feature (FeatureFlag): The feature to check

        Returns:
            bool: True only if the flag is stored as truthy and the current mode
            is allowed to use it.
        """
        if not self.internal and feature not in FeatureFlag.get_external_features():
            # Only enable select features for external use
            return False
        return bool(self.feature_flags.get(feature.value, False))

    def __str__(self) -> str:
        data = self._get_formatted_dump()
        y = StringDumpYAML()
        return y.dump(data)
def feature_enabled(feature: FeatureFlag) -> bool:
    """Return whether ``feature`` is enabled according to the config on disk.

    The config file is reloaded on every call.
    """
    return MCLIConfig.load_config().feature_enabled(feature=feature)