Source code for clarifai.client.model

import os
import time
from typing import Dict, List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.resources_pb2 import Input
from clarifai_grpc.grpc.api.status import status_code_pb2

from clarifai.client.base import BaseClient
from clarifai.client.lister import Lister
from clarifai.errors import UserError
from clarifai.urls.helper import ClarifaiUrlHelper
from clarifai.utils.logging import get_logger
from clarifai.utils.misc import BackoffIterator


[docs]class Model(Lister, BaseClient): """Model is a class that provides access to Clarifai API endpoints related to Model information."""
[docs] def __init__(self, url_init: str = "", model_id: str = "", model_version: Dict = {'id': ""}, output_config: Dict = {'min_value': 0}, **kwargs): """Initializes a Model object. Args: url_init (str): The URL to initialize the model object. model_id (str): The Model ID to interact with. model_version (dict): The Model Version to interact with. output_config (dict): The output config to interact with. min_value (float): The minimum value of the prediction confidence to filter. max_concepts (int): The maximum number of concepts to return. select_concepts (list[Concept]): The concepts to select. sample_ms (int): The number of milliseconds to sample. **kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper. """ if url_init != "" and model_id != "": raise UserError("You can only specify one of url_init or model_id.") if url_init == "" and model_id == "": raise UserError("You must specify one of url_init or model_id.") if url_init != "": user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url( url_init) model_version = {'id': model_version_id} kwargs = {'user_id': user_id, 'app_id': app_id} self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, 'output_info': {'output_config': output_config}} self.model_info = resources_pb2.Model(**self.kwargs) self.logger = get_logger(logger_level="INFO") BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id) Lister.__init__(self)
[docs] def predict(self, inputs: List[Input]): """Predicts the model based on the given inputs. Args: inputs (list[Input]): The inputs to predict, must be less than 128. """ if len(inputs) > 128: raise UserError("Too many inputs. Max is 128.") # TODO Use Chunker for inputs len > 128 request = service_pb2.PostModelOutputsRequest( user_app_id=self.user_app_id, model_id=self.id, version_id=self.model_version.id, inputs=inputs, model=self.model_info) start_time = time.time() backoff_iterator = BackoffIterator() while True: response = self._grpc_request(self.STUB.PostModelOutputs, request) if response.outputs and \ response.outputs[0].status.code == status_code_pb2.MODEL_DEPLOYING and \ time.time() - start_time < 60 * 10: # 10 minutes self.logger.info(f"{self.id} model is still deploying, please wait...") time.sleep(next(backoff_iterator)) continue if response.status.code != status_code_pb2.SUCCESS: raise Exception(f"Model Predict failed with response {response.status!r}") else: break return response
[docs] def predict_by_filepath(self, filepath: str, input_type: str): """Predicts the model based on the given filepath. Args: filepath (str): The filepath to predict. input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio. Example: >>> from clarifai.client.model import Model >>> model = Model("model_url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition or >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id') >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg', 'image') >>> model_prediction = model.predict_by_filepath('/path/to/text.txt', 'text') """ if input_type not in ['image', 'text', 'video', 'audio']: raise UserError('Invalid input type it should be image, text, video or audio.') if not os.path.isfile(filepath): raise UserError('Invalid filepath.') with open(filepath, "rb") as f: file_bytes = f.read() return self.predict_by_bytes(file_bytes, input_type)
[docs] def predict_by_bytes(self, input_bytes: bytes, input_type: str): """Predicts the model based on the given bytes. Args: input_bytes (bytes): File Bytes to predict on. input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'. Example: >>> from clarifai.client.model import Model >>> model = Model("https://clarifai.com/anthropic/completion/models/claude-v2") >>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI', 'text') """ if input_type not in {'image', 'text', 'video', 'audio'}: raise UserError('Invalid input type it should be image, text, video or audio.') if not isinstance(input_bytes, bytes): raise UserError('Invalid bytes.') # TODO will obtain proto from input class if input_type == "image": input_proto = resources_pb2.Input( data=resources_pb2.Data(image=resources_pb2.Image(base64=input_bytes))) elif input_type == "text": input_proto = resources_pb2.Input( data=resources_pb2.Data(text=resources_pb2.Text(raw=input_bytes))) elif input_type == "video": input_proto = resources_pb2.Input( data=resources_pb2.Data(video=resources_pb2.Video(base64=input_bytes))) elif input_type == "audio": input_proto = resources_pb2.Input( data=resources_pb2.Data(audio=resources_pb2.Audio(base64=input_bytes))) return self.predict(inputs=[input_proto])
[docs] def predict_by_url(self, url: str, input_type: str): """Predicts the model based on the given URL. Args: url (str): The URL to predict. input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio. Example: >>> from clarifai.client.model import Model >>> model = Model("model_url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition or >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id') >>> model_prediction = model.predict_by_url('url', 'image') """ if input_type not in {'image', 'text', 'video', 'audio'}: raise UserError('Invalid input type it should be image, text, video or audio.') # TODO will be obtain proto from input class if input_type == "image": input_proto = resources_pb2.Input( data=resources_pb2.Data(image=resources_pb2.Image(url=url))) elif input_type == "text": input_proto = resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(url=url))) elif input_type == "video": input_proto = resources_pb2.Input( data=resources_pb2.Data(video=resources_pb2.Video(url=url))) elif input_type == "audio": input_proto = resources_pb2.Input( data=resources_pb2.Data(audio=resources_pb2.Audio(url=url))) return self.predict(inputs=[input_proto])
[docs] def list_versions(self) -> List['Model']: """Lists all the versions for the model. Returns: List[Model]: A list of Model objects for the versions of the model. Example: >>> from clarifai.client.model import Model >>> model = Model("model_url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition or >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id') >>> all_model_versions = model.list_versions() """ request_data = dict( user_app_id=self.user_app_id, model_id=self.id, per_page=self.default_page_size, ) all_model_versions_info = list( self.list_all_pages_generator(self.STUB.ListModelVersions, service_pb2.ListModelVersionsRequest, request_data)) for model_version_info in all_model_versions_info: model_version_info['id'] = model_version_info['model_version_id'] del model_version_info['model_version_id'] return [ Model(model_id=self.id, **dict(self.kwargs, model_version=model_version_info)) for model_version_info in all_model_versions_info ]
def __getattr__(self, name): return getattr(self.model_info, name) def __str__(self): init_params = [param for param in self.kwargs.keys()] attribute_strings = [ f"{param}={getattr(self.model_info, param)}" for param in init_params if hasattr(self.model_info, param) ] return f"Model Details: \n{', '.join(attribute_strings)}\n"