Source code for mcli.api.inference_deployments.api_predict_inference_deployment

""" Predict on an Inference Deployment """
from __future__ import annotations

import json
from http import HTTPStatus
from typing import Any, Callable, Dict, Generator, Optional, Union, cast

import requests
import validators
from requests import Response

from mcli import config
from mcli.api.exceptions import InferenceServerException
from mcli.api.inference_deployments import get_inference_deployment
from mcli.api.model.inference_deployment import InferenceDeployment

# Public API of this module: only `predict` is exported by `from ... import *`.
__all__ = ['predict']


def predict(
    deployment: Union[InferenceDeployment, str],
    inputs: Dict[str, Any],
    *,
    timeout: Optional[int] = 60,
    stream: bool = False,
) -> Union[Dict[str, Any], Generator[str, None, None]]:
    """Sends input to '/predict' endpoint of an inference deployment on the MosaicML platform.

    Runs prediction on input and returns output produced by the model.

    Arguments:
        deployment: The deployment to make a prediction with. Can be an InferenceDeployment object,
            the name of a deployment, or a string of the form https://<deployment dns>.
        inputs: Input data to run prediction on, in the form of a dictionary. It is sent as the
            JSON request body.
        timeout: Time, in seconds, in which the call should complete. It is passed to ``requests``
            as the request timeout. If reading the response takes too long, an
            InferenceServerException built with ``HTTPStatus.REQUEST_TIMEOUT`` is raised.
        stream: If True, the request is sent to the '/predict_stream' endpoint and a generator is
            returned. Streaming supports only a single input at a time.

    Returns:
        If ``stream`` is False, the decoded JSON body of the '/predict' response.

        If ``stream`` is True, a generator that yields each non-empty line of the
        '/predict_stream' response decoded as JSON. Falsy decoded values are skipped. The
        streaming request is only sent when the generator is first iterated, so streaming errors
        are raised during iteration rather than by this call.

    Raises:
        InferenceServerException: Raised in any of these cases:
            * the server responds with an error status (``Response.ok`` is False);
            * the '/predict' response body is not valid JSON;
            * the request times out while reading;
            * the connection fails.
        json.JSONDecodeError: In streaming mode, if a streamed line is not valid JSON.
        MAPIException: If a MAPI communication error occurs, e.g. while resolving a deployment
            name with ``get_inference_deployment``.
    """
    # cast() only changes the declared type for pyright; it has no runtime effect.
    validate_url = cast(Callable[[str], bool], validators.url)
    if isinstance(deployment, str) and not validate_url(deployment):
        # A string that is not a URL is treated as a deployment name and looked up.
        deployment = get_inference_deployment(deployment)

    conf = config.MCLIConfig.load_config()
    headers = {
        'authorization': conf.api_key,
    }

    if isinstance(deployment, InferenceDeployment):
        base_url = f'https://{deployment.public_dns}'
    else:
        base_url = deployment
    # Avoid a double slash ('https://host//predict') when the caller's URL ends with '/'.
    base_url = base_url.rstrip('/')

    if stream:
        # Delegated to a separate generator function so `predict` itself stays a regular
        # function. A `yield` in this body would make every call return a generator.
        return _stream_predictions(f'{base_url}/predict_stream', inputs, headers, timeout)

    try:
        resp: Response = requests.post(url=f'{base_url}/predict', timeout=timeout, json=inputs, headers=headers)
    except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError) as e:
        raise _to_inference_server_error(e) from e

    _raise_for_error_status(resp)
    try:
        return resp.json()
    except requests.JSONDecodeError as e:
        raise InferenceServerException.from_bad_response(resp) from e


def _stream_predictions(
    url: str,
    inputs: Dict[str, Any],
    headers: Dict[str, Any],
    timeout: Optional[int],
) -> Generator[Any, None, None]:
    """Yield each non-empty line streamed from ``url``, decoded as JSON, skipping falsy values.

    The error translation lives inside this generator because its body only runs when the caller
    iterates. By then ``predict`` has already returned, and its own ``try`` block is gone.
    """
    try:
        with requests.post(url=url, timeout=timeout, json=inputs, headers=headers, stream=True) as resp:
            # Mirror the non-streaming path: an error status raises instead of being parsed as data.
            _raise_for_error_status(resp)
            for line in resp.iter_lines():
                if not line:
                    continue
                loaded = json.loads(line)
                if loaded:
                    yield loaded
    except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError) as e:
        raise _to_inference_server_error(e) from e


def _raise_for_error_status(resp: Response) -> None:
    """Raise an InferenceServerException if ``resp`` has an error status (``resp.ok`` is False)."""
    if not resp.ok:
        # errors='replace' makes a non-UTF-8 error body still surface as an
        # InferenceServerException instead of a UnicodeDecodeError.
        message = resp.content.decode(errors='replace').strip()
        raise InferenceServerException.from_server_error_response(message, resp.status_code)


def _to_inference_server_error(
    err: Union[requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError],
) -> InferenceServerException:
    """Map a requests read timeout or connection failure to an InferenceServerException."""
    if isinstance(err, requests.exceptions.ReadTimeout):
        return InferenceServerException.from_server_error_response(str(err), HTTPStatus.REQUEST_TIMEOUT)
    return InferenceServerException.from_requests_error(err)