Commit
Merge branch 'feat/model-runtime' into deploy/dev
takatost committed Jan 2, 2024
2 parents 768e0d7 + 57a5837 commit 293e0b4
Showing 12 changed files with 216 additions and 57 deletions.
2 changes: 1 addition & 1 deletion api/controllers/console/app/app.py
@@ -109,7 +109,7 @@ def post(self):
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict,
mode=args['mode']
app_mode=args['mode']
)

app = App(
2 changes: 1 addition & 1 deletion api/controllers/console/app/model_config.py
@@ -31,7 +31,7 @@ def post(self, app_id):
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app.mode
app_mode=app.mode
)

new_app_model_config = AppModelConfig(
18 changes: 12 additions & 6 deletions api/core/application_manager.py
@@ -315,7 +315,7 @@ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_d
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("messages", []):
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
@@ -328,12 +328,18 @@ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_d
advanced_completion_prompt_template = None
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
}

if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}

advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
prompt=completion_prompt_config['prompt']['text'],
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
user=completion_prompt_config['conversation_histories_role']['user_prefix'],
assistant=completion_prompt_config['conversation_histories_role']['assistant_prefix']
)
**completion_prompt_template_params
)

properties['prompt_template'] = PromptTemplateEntity(
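Note: the conditional build of completion_prompt_template_params above exists because conversation_histories_role is optional in the app model config. A minimal sketch of the two config shapes this hunk now accepts; the field names come from the diff, the prompt text is a placeholder:

# Illustrative only: two completion_prompt_config payloads the converter handles.
with_role_prefix = {
    'prompt': {'text': 'You are a helpful assistant. Answer the user.'},
    'conversation_histories_role': {
        'user_prefix': 'Human',
        'assistant_prefix': 'Assistant',
    },
}

without_role_prefix = {
    'prompt': {'text': 'Summarize the following text.'},
}
# Only the first shape adds a 'role_prefix' entry to completion_prompt_template_params.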
2 changes: 1 addition & 1 deletion api/core/entities/application_entities.py
@@ -50,7 +50,7 @@ class RolePrefixEntity(BaseModel):
assistant: str

prompt: str
role_prefix: RolePrefixEntity
role_prefix: Optional[RolePrefixEntity] = None


class PromptTemplateEntity(BaseModel):
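Note: with role_prefix now Optional and defaulting to None, the entity can be built from a prompt alone, which is what the conditional kwargs in application_manager.py rely on. A minimal sketch, assuming the import path matches the file shown above:

from core.entities.application_entities import AdvancedCompletionPromptTemplateEntity  # import path assumed

# Omitting role_prefix is now valid.
template = AdvancedCompletionPromptTemplateEntity(prompt='Answer briefly.')
assert template.role_prefix is None

# Supplying it still works as before.
template_with_prefix = AdvancedCompletionPromptTemplateEntity(
    prompt='Answer briefly.',
    role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
        user='Human',
        assistant='Assistant',
    ),
)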
6 changes: 5 additions & 1 deletion api/core/model_runtime/model_providers/google/llm/llm.py
@@ -17,6 +17,9 @@
from core.model_runtime.model_providers import google
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

import logging
logger = logging.getLogger(__name__)

class GoogleLargeLanguageModel(LargeLanguageModel):

def _invoke(self, model: str, credentials: dict,
@@ -198,14 +201,14 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon
index = -1
for chunk in response:
content = chunk.text

index += 1

assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)

if not response._done:

# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
@@ -216,6 +219,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon
)
)
else:

# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
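Note: the handler above yields intermediate LLMResultChunk objects until the Gemini response reports completion, and only the final chunk carries usage. A hedged consumer-side sketch; the instance name and credential key are placeholders, the keyword arguments follow the runtime's _invoke signature:

from core.model_runtime.entities.message_entities import UserPromptMessage

chunks = google_llm._invoke(                 # 'google_llm' is a hypothetical GoogleLargeLanguageModel instance
    model='gemini-pro',
    credentials={'google_api_key': '...'},   # credential key name assumed
    prompt_messages=[UserPromptMessage(content='ping')],
    model_parameters={'temperature': 0.7},
    stream=True,
)
for chunk in chunks:
    print(chunk.delta.message.content, end='')
    if chunk.delta.usage:                    # populated only on the final chunk, per the branch above
        print(f"\n{chunk.delta.usage.total_tokens} tokens")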
@@ -35,3 +35,89 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
requests.exceptions.ReadTimeout # Timeout
]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING

entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=model_type,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=float(credentials.get('temperature', 1)),
min=0,
max=2
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
default=float(credentials.get('top_p', 1)),
min=0,
max=1
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get('top_k', 1)),
min=1,
max=100
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('frequency_penalty', 0)),
min=-2,
max=2
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="PRESENCE Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('PRESENCE_penalty', 0)),
min=-2,
max=2
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
default=1024,
min=1,
max=int(credentials.get('max_tokens_to_sample', 4096)),
)
],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
output=Decimal(credentials.get('output_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)

if model_type == ModelType.LLM:
if credentials['mode'] == 'chat':
entity.model_properties['mode'] = LLMMode.CHAT
elif credentials['mode'] == 'completion':
entity.model_properties['mode'] = LLMMode.COMPLETION
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")

return entity
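Note: a hedged usage sketch for the schema builder above. The provider instance and credential values are illustrative; only the keys and enum names come from this diff:

credentials = {
    '__model_type': 'llm',
    'mode': 'chat',
    'context_size': 8192,
    'temperature': 0.7,
    'top_p': 0.9,
    'max_tokens_to_sample': 4096,
    'input_price': '0.001',
    'output_price': '0.002',
    'unit': '0.001',
    'currency': 'USD',
}
entity = compat_llm.get_customizable_model_schema('my-model', credentials)  # 'compat_llm' is a hypothetical provider instance
assert entity.fetch_from == FetchFrom.CUSTOMIZABLE_MODEL
assert entity.model_properties['mode'] is LLMMode.CHAT  # set by the branch at the end of the method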
@@ -6,14 +6,16 @@

from typing import Optional, Generator, Union, List, cast

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.utils import helper

from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -81,18 +83,35 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
try:
headers = {
'Authorization': f'Bearer {credentials["api_key"]}',
'Content-Type': 'application/json'
}

api_key = credentials.get('api_key')
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

endpoint_url = credentials['endpoint_url']

# prepare the payload for a simple ping to the model
data = {
'model': model,
'prompt': 'ping',
'max_tokens': 5
}

completion_type = LLMMode.value_of(credentials['mode'])

if completion_type is LLMMode.CHAT:
data['messages'] = [
{
"role": "user",
"content": "ping"
},
]
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
else:
raise ValueError("Unsupported completion type for model configuration.")

# send a post request to validate the credentials
response = requests.post(
endpoint_url,
@@ -198,22 +217,30 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
:return: full response or stream response chunk generator result
"""
headers = {
'Authorization': f'Bearer {credentials["api_key"]}',
'Content-Type': 'application/json'
}

endpoint_url = credentials["endpoint_url"]

model_config = self._extract_model_config(model, credentials)
api_key = credentials.get('api_key')
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

endpoint_url = credentials["endpoint_url"]

data = {
"messages": [self._convert_prompt_message_to_dict(m) for m in prompt_messages],
"model": model,
"stream": stream,
**model_parameters,
**model_config,
**model_parameters
}

completion_type = LLMMode.value_of(credentials['mode'])

if completion_type is LLMMode.CHAT:
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")

# annotate tools with names, descriptions, etc.
formatted_tools = []
if tools:
@@ -238,30 +265,17 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
stream=stream
)

# Debug: Print request headers and json data
logger.debug(f"Request headers: {headers}")
logger.debug(f"Request JSON data: {data}")

if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

return self._handle_generate_response(model, credentials, response, prompt_messages)

def _extract_model_config(self, model: str, credentials: dict) -> dict:
"""
extract customized LLM model configurations via credentials
"""
model_mode = self.get_model_mode(model)

config = {}

if model_mode == ModelType.LLM:
config['temperature'] = credentials.get('temperature')
config['top_p'] = credentials.get('top_p')
config['top_k'] = credentials.get('top_k')
config['frequency_penalty'] = credentials.get('frequency_penalty')
config['max_tokens_to_sample'] = credentials.get('max_tokens_to_sample')

return config

def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
@@ -303,7 +317,7 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f

try:
chunk_json = json.loads(decoded_chunk)
# stream ended by
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
@@ -366,12 +380,26 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req
prompt_messages: list[PromptMessage]) -> LLMResult:

response_json = response.json()
assistant_message = AssistantPromptMessage(content=response_json['choices'][0]['message']['content'])
tool_calls = response_json['choices'][0]['message'].get('tool_calls', None)

completion_type = LLMMode.value_of(credentials['mode'])

output = response_json['choices'][0]

response_content = ''
tool_calls = None

if completion_type is LLMMode.CHAT:
response_content = output.get('message', {})['content']
tool_calls = output.get('message', {}).get('tool_calls')

elif completion_type is LLMMode.COMPLETION:
response_content = output['text']

assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])

if tool_calls:
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)

usage = response_json.get("usage")
if usage:
# transform usage
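Note: the branch on completion_type above mirrors the two response shapes an OpenAI-compatible server returns. A minimal sketch of the JSON each branch expects; the values are placeholders:

# Chat mode: content (and optional tool calls) live under choices[0].message.
chat_response = {
    "choices": [{"message": {"content": "Hello!", "tool_calls": []}}],
    "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
}

# Completion mode: the text sits directly on choices[0].
completion_response = {
    "choices": [{"text": "Hello!"}],
    "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
}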
@@ -35,6 +35,27 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API endpoint URL
en_US: Enter your API endpoint URL
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
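Note: the new mode credential added above is what validate_credentials and _generate read via LLMMode.value_of. A minimal sketch of how the selected value steers the request payload; the endpoint, key, and model name are placeholders:

from core.model_runtime.entities.llm_entities import LLMMode  # import shown in the llm.py diff above

credentials = {
    'endpoint_url': 'https://example.com/v1',  # placeholder
    'api_key': 'sk-...',                        # optional per this diff
    'mode': 'chat',                             # the new credential; or 'completion'
}

data = {'model': 'my-model', 'stream': False, 'max_tokens': 5}
if LLMMode.value_of(credentials['mode']) is LLMMode.CHAT:
    data['messages'] = [{'role': 'user', 'content': 'ping'}]
else:
    data['prompt'] = 'ping'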