Skip to content

Commit

Permalink
refactor replicate.
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Jan 1, 2024
1 parent 120ecb5 commit de00184
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 96 deletions.
15 changes: 15 additions & 0 deletions api/core/model_runtime/model_providers/replicate/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from replicate.exceptions import ReplicateError, ModelError

from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError


class _CommonReplicate:

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
ReplicateError,
ModelError
]
}
138 changes: 81 additions & 57 deletions api/core/model_runtime/model_providers/replicate/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from typing import Optional, List, Union, Generator

from replicate import Client as ReplicateClient
from replicate.exceptions import ReplicateError, ModelError
from replicate.exceptions import ReplicateError
from replicate.prediction import Prediction

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMMode, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
PromptMessageRole
PromptMessageRole, UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.replicate._common import _CommonReplicate


class ReplicateLargeLanguageModel(LargeLanguageModel):
class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):

def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:

version = credentials['model_version']

client = ReplicateClient(api_token=credentials['replicate_api_token'])
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)

Expand All @@ -40,39 +40,13 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
)

if stream:
return self._handle_generate_stream_response(model, prediction, stop, prompt_messages)
return self._handle_generate_response(model, prediction, stop, prompt_messages)

@staticmethod
def _get_llm_usage():
usage = LLMUsage(
prompt_tokens=0,
prompt_unit_price=0,
prompt_price_unit=0,
prompt_price=0,
completion_tokens=0,
completion_unit_price=0,
completion_price_unit=0,
completion_price=0,
total_tokens=0,
total_price=0,
currency='USD',
latency=0,
)
return usage
return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages)
return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
return 0
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)

def validate_credentials(self, model: str, credentials: dict) -> None:
if 'replicate_api_token' not in credentials:
Expand All @@ -88,7 +62,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
version = credentials['model_version']

try:
client = ReplicateClient(api_token=credentials['replicate_api_token'])
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)

Expand Down Expand Up @@ -128,15 +102,27 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
version = credentials['model_version']

client = ReplicateClient(api_token=credentials['replicate_api_token'])
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)

parameter_rules = []

for key, value in model_info_version.openapi_schema['components']['schemas']['Input']['properties'].items():
input_properties = sorted(
model_info_version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)

for key, value in input_properties:
if key not in ['system_prompt', 'prompt']:
param_type = cls._get_parameter_type(value['type'])
value_type = value.get('type')

if not value_type:
continue

param_type = cls._get_parameter_type(value_type)

rule = ParameterRule(
name=key,
Expand All @@ -156,20 +142,13 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict)

return parameter_rules

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
ReplicateError,
ModelError
]
}

def _handle_generate_stream_response(self, model: str,
def _handle_generate_stream_response(self,
model: str,
credentials: dict,
prediction: Prediction,
stop: list[str],
prompt_messages: list[PromptMessage]) -> Generator:
index = 0
index = -1
current_completion: str = ""
stop_condition_reached = False
for output in prediction.output_iterator():
Expand All @@ -187,18 +166,28 @@ def _handle_generate_stream_response(self, model: str,
if stop_condition_reached:
break

index += 1

assistant_prompt_message = AssistantPromptMessage(
content=output if output else ''
)

prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=output),
usage=self._get_llm_usage(),
message=assistant_prompt_message,
usage=usage,
),
)
index += 1

def _handle_generate_response(self, model: str, prediction: Prediction, stop: list[str],
def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
prompt_messages: list[PromptMessage]) -> LLMResult:
current_completion: str = ""
stop_condition_reached = False
Expand All @@ -217,11 +206,19 @@ def _handle_generate_response(self, model: str, prediction: Prediction, stop: li
if stop_condition_reached:
break

usage = self._get_llm_usage()
assistant_prompt_message = AssistantPromptMessage(
content=current_completion
)

prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=current_completion),
message=assistant_prompt_message,
usage=usage,
)

Expand All @@ -237,3 +234,30 @@ def _get_parameter_type(cls, param_type: str) -> str:
return 'boolean'
elif param_type == 'string':
return 'string'

def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list

text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)

return text.rstrip()

@staticmethod
def _convert_one_message_to_text(message: PromptMessage) -> str:
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")

return message_text
Original file line number Diff line number Diff line change
@@ -1,37 +1,31 @@
import json
import time
from typing import Optional

from replicate import Client as ReplicateClient
from replicate.exceptions import ReplicateError, ModelError

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.replicate._common import _CommonReplicate


class ReplicateEmbeddingModel(TextEmbeddingModel):
class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:

client = ReplicateClient(api_token=credentials['replicate_api_token'])
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'

text_input_key = self._get_text_input_key(model, credentials['model_version'], client)

embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
texts)
usage = EmbeddingUsage(
tokens=0,
total_tokens=0,
unit_price=0.0,
price_unit=0.0,
total_price=0.0,
currency='USD',
latency=0.0
)

tokens = self.get_num_tokens(model, credentials, texts)
usage = self._calc_response_usage(model, credentials, tokens)

return TextEmbeddingResult(
model=model,
Expand All @@ -40,15 +34,10 @@ def _invoke(self, model: str, credentials: dict, texts: list[str],
)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return 0
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def validate_credentials(self, model: str, credentials: dict) -> None:
if 'replicate_api_token' not in credentials:
Expand All @@ -58,7 +47,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')

try:
client = ReplicateClient(api_token=credentials['replicate_api_token'])
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'

text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
Expand All @@ -83,15 +72,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
)
return entity

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
ReplicateError,
ModelError
]
}

@staticmethod
def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) -> str:
model_info = client.models.get(model)
Expand Down Expand Up @@ -135,3 +115,24 @@ def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_mo
return result
else:
raise ValueError(f'embeddings input key is invalid: {text_input_key}')

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ def test_get_num_tokens():
]
)

assert num_tokens == 0
assert num_tokens == 14
Loading

0 comments on commit de00184

Please sign in to comment.