diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py
new file mode 100644
index 00000000000000..1a14d29cea06e2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/replicate/_common.py
@@ -0,0 +1,15 @@
+from replicate.exceptions import ReplicateError, ModelError
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
+
+
+class _CommonReplicate:
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeBadRequestError: [
+                ReplicateError,
+                ModelError
+            ]
+        }
diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py
index 0a6f1cc2a58a5e..8691833f0fa3c8 100644
--- a/api/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -1,20 +1,20 @@
 from typing import Optional, List, Union, Generator
 
 from replicate import Client as ReplicateClient
-from replicate.exceptions import ReplicateError, ModelError
+from replicate.exceptions import ReplicateError
 from replicate.prediction import Prediction
 
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMMode, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
-    PromptMessageRole
+    PromptMessageRole, UserPromptMessage, SystemPromptMessage
 from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
-from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.replicate._common import _CommonReplicate
 
 
-class ReplicateLargeLanguageModel(LargeLanguageModel):
+class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True,
@@ -22,7 +22,7 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
 
         version = credentials['model_version']
 
-        client = ReplicateClient(api_token=credentials['replicate_api_token'])
+        client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
 
         model_info = client.models.get(model)
         model_info_version = model_info.versions.get(version)
@@ -40,39 +40,13 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
         )
 
         if stream:
-            return self._handle_generate_stream_response(model, prediction, stop, prompt_messages)
-        return self._handle_generate_response(model, prediction, stop, prompt_messages)
-
-    @staticmethod
-    def _get_llm_usage():
-        usage = LLMUsage(
-            prompt_tokens=0,
-            prompt_unit_price=0,
-            prompt_price_unit=0,
-            prompt_price=0,
-            completion_tokens=0,
-            completion_unit_price=0,
-            completion_price_unit=0,
-            completion_price=0,
-            total_tokens=0,
-            total_price=0,
-            currency='USD',
-            latency=0,
-        )
-        return usage
+            return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages)
+        return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
-        """
-        Get number of tokens for given prompt messages
-
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param tools: tools for tool calling
-        :return:
-        """
-        return 0
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+        return self._get_num_tokens_by_gpt2(prompt)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'replicate_api_token' not in credentials:
@@ -88,7 +62,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         version = credentials['model_version']
 
         try:
-            client = ReplicateClient(api_token=credentials['replicate_api_token'])
+            client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
 
             model_info = client.models.get(model)
             model_info_version = model_info.versions.get(version)
@@ -128,15 +102,27 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
     def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
         version = credentials['model_version']
 
-        client = ReplicateClient(api_token=credentials['replicate_api_token'])
+        client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         model_info = client.models.get(model)
         model_info_version = model_info.versions.get(version)
 
         parameter_rules = []
 
-        for key, value in model_info_version.openapi_schema['components']['schemas']['Input']['properties'].items():
+        input_properties = sorted(
+            model_info_version.openapi_schema["components"]["schemas"]["Input"][
+                "properties"
+            ].items(),
+            key=lambda item: item[1].get("x-order", 0),
+        )
+
+        for key, value in input_properties:
             if key not in ['system_prompt', 'prompt']:
-                param_type = cls._get_parameter_type(value['type'])
+                value_type = value.get('type')
+
+                if not value_type:
+                    continue
+
+                param_type = cls._get_parameter_type(value_type)
 
                 rule = ParameterRule(
                     name=key,
@@ -156,20 +142,13 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict)
 
         return parameter_rules
 
-    @property
-    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
-        return {
-            InvokeBadRequestError: [
-                ReplicateError,
-                ModelError
-            ]
-        }
-
-    def _handle_generate_stream_response(self, model: str,
+    def _handle_generate_stream_response(self,
+                                         model: str,
+                                         credentials: dict,
                                          prediction: Prediction,
                                          stop: list[str],
                                          prompt_messages: list[PromptMessage]) -> Generator:
-        index = 0
+        index = -1
         current_completion: str = ""
         stop_condition_reached = False
         for output in prediction.output_iterator():
@@ -187,18 +166,28 @@ def _handle_generate_stream_response(self, model: str,
             if stop_condition_reached:
                 break
 
+            index += 1
+
+            assistant_prompt_message = AssistantPromptMessage(
+                content=output if output else ''
+            )
+
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
             yield LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=output),
-                    usage=self._get_llm_usage(),
+                    message=assistant_prompt_message,
+                    usage=usage,
                 ),
             )
 
-            index += 1
-
-    def _handle_generate_response(self, model: str, prediction: Prediction, stop: list[str],
+    def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         current_completion: str = ""
         stop_condition_reached = False
@@ -217,11 +206,19 @@ def _handle_generate_response(self, model: str, prediction: Prediction, stop: li
             if stop_condition_reached:
                 break
 
-        usage = self._get_llm_usage()
+        assistant_prompt_message = AssistantPromptMessage(
+            content=current_completion
+        )
+
+        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
         result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=current_completion),
+            message=assistant_prompt_message,
             usage=usage,
         )
 
@@ -237,3 +234,30 @@ def _get_parameter_type(cls, param_type: str) -> str:
             return 'boolean'
         elif param_type == 'string':
             return 'string'
+
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        messages = messages.copy()  # don't mutate the original list
+
+        text = "".join(
+            self._convert_one_message_to_text(message)
+            for message in messages
+        )
+
+        return text.rstrip()
+
+    @staticmethod
+    def _convert_one_message_to_text(message: PromptMessage) -> str:
+        human_prompt = "\n\nHuman:"
+        ai_prompt = "\n\nAssistant:"
+        content = message.content
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = content
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
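Note on the `llm.py` changes above: `get_num_tokens` no longer hard-codes 0. It flattens the conversation into an Anthropic-style `Human:`/`Assistant:` transcript via `_convert_messages_to_prompt` and counts GPT-2 tokens with the `_get_num_tokens_by_gpt2` helper inherited from the shared model base class. A minimal standalone sketch of that flattening, using hypothetical stand-in message classes rather than the runtime's real `PromptMessage` types:

```python
# Standalone sketch of the prompt flattening added in llm.py. The message
# classes here are stand-ins for UserPromptMessage / AssistantPromptMessage /
# SystemPromptMessage, so this file runs without the runtime installed.
from dataclasses import dataclass


@dataclass
class SystemMessage:
    content: str


@dataclass
class UserMessage:
    content: str


@dataclass
class AssistantMessage:
    content: str


def convert_one_message_to_text(message) -> str:
    human_prompt = "\n\nHuman:"
    ai_prompt = "\n\nAssistant:"
    if isinstance(message, UserMessage):
        return f"{human_prompt} {message.content}"
    if isinstance(message, AssistantMessage):
        return f"{ai_prompt} {message.content}"
    if isinstance(message, SystemMessage):
        # System content is passed through bare, with no role prefix.
        return message.content
    raise ValueError(f"Got unknown type {message}")


def convert_messages_to_prompt(messages) -> str:
    return "".join(convert_one_message_to_text(m) for m in messages).rstrip()


if __name__ == "__main__":
    prompt = convert_messages_to_prompt([
        SystemMessage("You are a helpful assistant."),
        UserMessage("Hello!"),
    ])
    print(repr(prompt))  # 'You are a helpful assistant.\n\nHuman: Hello!'
```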
diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
index c7f1c708cac45b..3d6fdc74a77db2 100644
--- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
@@ -1,37 +1,31 @@
 import json
+import time
 from typing import Optional
 
 from replicate import Client as ReplicateClient
-from replicate.exceptions import ReplicateError, ModelError
 
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
-from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.replicate._common import _CommonReplicate
 
 
-class ReplicateEmbeddingModel(TextEmbeddingModel):
+class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
 
     def _invoke(self, model: str, credentials: dict, texts: list[str],
                 user: Optional[str] = None) -> TextEmbeddingResult:
-        client = ReplicateClient(api_token=credentials['replicate_api_token'])
+        client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         replicate_model_version = f'{model}:{credentials["model_version"]}'
 
         text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
 
         embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
                                                                  texts)
 
-        usage = EmbeddingUsage(
-            tokens=0,
-            total_tokens=0,
-            unit_price=0.0,
-            price_unit=0.0,
-            total_price=0.0,
-            currency='USD',
-            latency=0.0
-        )
+
+        tokens = self.get_num_tokens(model, credentials, texts)
+        usage = self._calc_response_usage(model, credentials, tokens)
 
         return TextEmbeddingResult(
             model=model,
@@ -40,15 +34,10 @@ def _invoke(self, model: str, credentials: dict, texts: list[str],
         )
 
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
-        """
-        Get number of tokens for given prompt messages
-
-        :param model: model name
-        :param credentials: model credentials
-        :param texts: texts to embed
-        :return:
-        """
-        return 0
+        num_tokens = 0
+        for text in texts:
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+        return num_tokens
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'replicate_api_token' not in credentials:
@@ -58,7 +47,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
             raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
 
         try:
-            client = ReplicateClient(api_token=credentials['replicate_api_token'])
+            client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
             replicate_model_version = f'{model}:{credentials["model_version"]}'
 
             text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
@@ -83,15 +72,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
         )
         return entity
 
-    @property
-    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
-        return {
-            InvokeBadRequestError: [
-                ReplicateError,
-                ModelError
-            ]
-        }
-
     @staticmethod
    def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) -> str:
         model_info = client.models.get(model)
@@ -135,3 +115,24 @@ def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_mo
             return result
         else:
             raise ValueError(f'embeddings input key is invalid: {text_input_key}')
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
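On the `text_embedding.py` side, usage is now derived from real token counts and the provider's pricing config: `_calc_response_usage` calls `self.get_price(...)` and stamps latency as `time.perf_counter() - self.started_at`, a start timestamp presumably recorded by the shared base model when invocation begins. A back-of-the-envelope check of the pricing math, assuming the convention that `price_unit` is the token denomination (e.g. 0.001 for a price quoted per 1K tokens); the numbers are made up for illustration:

```python
# Hypothetical pricing check for a _calc_response_usage-style calculation.
from decimal import Decimal

unit_price = Decimal("0.1")    # $0.10 per 1K tokens (made-up figure)
price_unit = Decimal("0.001")  # denomination: price is per 1/price_unit tokens

tokens = 2  # what get_num_tokens returns for the two short test inputs
total_price = Decimal(tokens) * price_unit * unit_price
print(total_price)  # 0.0002
```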
diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py
index 33d6b0654ce378..61a4ab280742a7 100644
--- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py
@@ -116,4 +116,4 @@ def test_get_num_tokens():
         ]
     )
 
-    assert num_tokens == 0
+    assert num_tokens == 14
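The LLM test expectation moves from 0 to 14 because token counting is now real: the test's prompt messages (their contents are not shown in this diff) are flattened into a transcript and tokenized. Assuming `_get_num_tokens_by_gpt2` is backed by the standard GPT-2 BPE, such expectations can be re-derived offline along these lines; the transcript below is a placeholder, not the test's actual fixture:

```python
# Requires: pip install transformers
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def num_tokens(prompt: str) -> int:
    # Mirrors a plausible _get_num_tokens_by_gpt2: count GPT-2 BPE tokens.
    return len(tokenizer.encode(prompt))


# Placeholder transcript in the "\n\nHuman: ..." shape produced by
# _convert_messages_to_prompt; the real test messages may differ.
print(num_tokens("You are a helpful assistant.\n\nHuman: Hello!"))
```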
diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py
index 1cff98b76619c2..5708ec9e5a219e 100644
--- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py
@@ -67,7 +67,7 @@ def test_invoke_model_one():
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 0
+    assert result.usage.total_tokens == 2
 
 
 def test_invoke_model_two():
@@ -88,7 +88,7 @@ def test_invoke_model_two():
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 0
+    assert result.usage.total_tokens == 2
 
 
 def test_invoke_model_three():
@@ -109,7 +109,7 @@ def test_invoke_model_three():
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 0
+    assert result.usage.total_tokens == 2
 
 
 def test_invoke_model_four():
@@ -130,7 +130,7 @@ def test_invoke_model_four():
 
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 0
+    assert result.usage.total_tokens == 2
 
 
 def test_get_num_tokens():
@@ -148,4 +148,4 @@ def test_get_num_tokens():
         ]
     )
 
-    assert num_tokens == 0
+    assert num_tokens == 2
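Finally, the motivation for the new `_common.py`: both model classes previously carried identical `_invoke_error_mapping` properties, now hoisted into the `_CommonReplicate` mixin. The runtime uses this mapping to translate SDK exceptions into unified invoke errors. A minimal sketch of how such a mapping is typically consumed, with stand-in error classes; the dispatch function here is illustrative, not the runtime's actual implementation:

```python
from replicate.exceptions import ModelError, ReplicateError


class InvokeError(Exception):
    """Stand-in for core.model_runtime.errors.invoke.InvokeError."""


class InvokeBadRequestError(InvokeError):
    """Stand-in for the unified bad-request error."""


# Same shape as _CommonReplicate._invoke_error_mapping.
ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeBadRequestError: [ReplicateError, ModelError],
}


def transform_invoke_error(error: Exception) -> InvokeError:
    # Re-raise any mapped SDK exception as its unified counterpart,
    # falling back to the generic InvokeError.
    for invoke_error, sdk_errors in ERROR_MAPPING.items():
        if isinstance(error, tuple(sdk_errors)):
            return invoke_error(str(error))
    return InvokeError(str(error))
```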