# Source code for mcli.api.runs.api_update_run_metadata

""" Update the metadata of a run. """
from __future__ import annotations

import json
import logging
import math
from concurrent.futures import Future
from typing import Any, Dict, Optional, Union, overload

from typing_extensions import Literal

from mcli.api.engine.engine import get_return_response, run_singular_mapi_request
from mcli.api.model.run import Run
from mcli.config import MCLIConfig
from mcli.utils.utils_logging import WARN

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Only the update function is exported as public API of this module.
__all__ = ['update_run_metadata']

# Name of the GraphQL mutation field that QUERY invokes.
QUERY_FUNCTION = 'updateRunMetadata'
# Names of the two GraphQL variables the mutation accepts. They are used both
# inside QUERY and as the keys of the ``variables`` dict sent with the request:
#   * the run-selection input (typed ``GetRunsInput!`` in QUERY), and
#   * the metadata payload (typed ``UpdateRunMetadataInput!`` in QUERY).
VARIABLE_DATA_GET_RUNS = 'getRunsData'
VARIABLE_DATA_UPDATE_RUN_METADATA = 'updateRunMetadataData'
# GraphQL mutation document. This is an f-string: ``{{`` / ``}}`` render as
# literal braces, while single-brace fields interpolate the constants above.
# The selection set lists the run fields requested back from the server,
# including per-resumption details and the stored metadata.
QUERY = f"""
mutation UpdateRunMetadata(${VARIABLE_DATA_GET_RUNS}: GetRunsInput!, ${VARIABLE_DATA_UPDATE_RUN_METADATA}: UpdateRunMetadataInput!) {{
  {QUERY_FUNCTION}({VARIABLE_DATA_GET_RUNS}: ${VARIABLE_DATA_GET_RUNS}, {VARIABLE_DATA_UPDATE_RUN_METADATA}: ${VARIABLE_DATA_UPDATE_RUN_METADATA}) {{
    id
    name
    createdByEmail
    status
    createdAt
    updatedAt
    reason
    priority
    maxRetries
    preemptible
    retryOnSystemFailure
    runType
    isDeleted
    resumptions {{
        clusterName
        cpus
        gpuType
        gpus
        nodes
        executionIndex
        startTime
        endTime
        status
    }}
    details {{
        metadata
    }}
  }}
}}"""


# Typing-only overload: with ``future`` typed as ``Literal[False]`` (its
# default here), the call is declared to return the updated ``Run`` directly.
@overload
def update_run_metadata(run: Union[str, Run],
                        metadata: Dict[str, Any],
                        *,
                        timeout: Optional[float] = None,
                        future: Literal[False] = False,
                        protect: bool = False) -> Run:
    ...


# Typing-only overload: with ``future`` typed as ``Literal[True]``, the call is
# declared to return a ``Future`` that resolves to the updated ``Run``.
@overload
def update_run_metadata(run: Union[str, Run],
                        metadata: Dict[str, Any],
                        *,
                        timeout: Optional[float] = None,
                        future: Literal[True] = True,
                        protect: bool = False) -> Future[Run]:
    ...


# Typing-only overload: ``future`` is an arbitrary ``bool`` here, with
# ``timeout`` typed as ``None`` and ``protect`` typed as ``True``; the declared
# return is therefore either a ``Run`` or a ``Future[Run]``.
@overload
def update_run_metadata(run: Union[str, Run],
                        metadata: Dict[str, Any],
                        *,
                        timeout: Literal[None] = None,
                        future: bool = False,
                        protect: Literal[True] = True) -> Union[Run, Future[Run]]:
    ...

def update_run_metadata(run: Union[str, Run],
                        metadata: Dict[str, Any],
                        *,
                        timeout: Optional[float] = 10,
                        future: bool = False,
                        protect: bool = False):
    """Update a run's metadata in the MosaicML platform.

    Args:
        run (``str | ``:class:`~mcli.api.model.run.Run`): A run or run name to
            update. Only the run's name is used to select the run to modify.
        metadata (`Dict[str, Any]`): The metadata to update the run with. This will
            be merged with the existing metadata. Keys not specified in this
            dictionary will not be modified. Entries whose values are not JSON
            serializable are dropped (with a warning) before the request is sent.
        timeout (``Optional[float]``): Time, in seconds, in which the call should
            complete. If the call takes too long, a
            :exc:`~concurrent.futures.TimeoutError` will be raised. If ``future`` is
            ``True``, this value will be ignored.
        future (``bool``): Return the output as a :class:`~concurrent.futures.Future`.
            If True, the call to :func:`update_run_metadata` will return immediately
            and the request will be processed in the background. This takes
            precedence over the ``timeout`` argument. To get the updated
            :class:`~mcli.api.model.run.Run`, use ``return_value.result()`` with an
            optional ``timeout`` argument.
        protect (``bool``): If True, the call will be protected from SIGTERMs to
            allow it to complete reliably. Defaults to False.

    Raises:
        MAPIException: Raised if updating the requested run failed

    Returns:
        If future is False:
            Updated :class:`~mcli.api.model.run.Run` object
        Otherwise:
            A :class:`~concurrent.futures.Future` for the updated run
    """
    # Strip out anything that cannot be JSON-encoded before building the request.
    valid_metadata = validate_metadata(metadata)

    run_name = run.name if isinstance(run, Run) else run
    variables = {
        VARIABLE_DATA_GET_RUNS: {
            'filters': {
                'name': {
                    'in': [run_name]
                },
            }
        },
        VARIABLE_DATA_UPDATE_RUN_METADATA: {
            'metadata': valid_metadata
        },
    }

    # NOTE(review): presumably adjusts the run-selection input in place based on
    # the loaded configuration -- confirm against MCLIConfig.update_entity.
    cfg = MCLIConfig.load_config()
    cfg.update_entity(variables[VARIABLE_DATA_GET_RUNS])

    response = run_singular_mapi_request(
        query=QUERY,
        query_function=QUERY_FUNCTION,
        return_model_type=Run,
        variables=variables,
        protect=protect,
    )
    return get_return_response(response, future=future, timeout=timeout)


def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce a metadata dictionary to the entries that can be serialized to JSON.

    Every value is trial-serialized with :func:`serialize_value`. Entries whose
    value cannot be serialized are dropped and reported together in a single
    warning; no exception is raised for them. Non-finite floats (``nan``,
    ``inf``, ``-inf``) are replaced by the string :func:`json.dumps` produces for
    them (``'NaN'``, ``'Infinity'``, ``'-Infinity'``). All other values are
    passed through unchanged (not as their serialized form).

    Args:
        metadata (`Dict[str, Any]`): The metadata to validate.

    Returns:
        A new dictionary holding only the serializable entries of ``metadata``.
    """
    valid_metadata: Dict[str, Any] = {}
    invalid_keys = []
    for key, value in metadata.items():
        serialized_value, is_serializable = serialize_value(value)
        if not is_serializable:
            # Best effort: remember the key and carry on with the rest.
            invalid_keys.append(key)
            continue

        if isinstance(value, float) and not math.isfinite(value):
            # Store nan/inf as the string json.dumps emitted instead of the float.
            value = serialized_value
        valid_metadata[key] = value

    if invalid_keys:
        logger.warning(f"{WARN} Metadata value for key '{invalid_keys}' is not JSON serializable. Ignoring.")

    return valid_metadata


def serialize_value(value: Any) -> tuple[Any, bool]:
    """Determine if a value is JSON serializable and serialize it if possible.

    Args:
        value (`Any`): The value to try to encode with :func:`json.dumps`.

    Returns:
        ``(serialized, True)`` with the JSON string on success, or
        ``(None, False)`` if the value cannot be encoded.
    """
    try:
        return json.dumps(value), True
    except (TypeError, ValueError):
        # TypeError: unsupported type. ValueError: json.dumps detected a
        # circular reference. Either way the value cannot be sent as JSON.
        return None, False