
Commit

Merge branch 'feat/model-runtime' into deploy/dev
Yeuoly committed Dec 28, 2023
2 parents 3420454 + d033609 commit a833026
Showing 8 changed files with 29 additions and 13 deletions.
4 changes: 2 additions & 2 deletions api/core/model_runtime/model_providers/localai/llm/llm.py
@@ -150,8 +150,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
             ], model_parameters={
                 'max_tokens': 10,
             }, stop=[])
-        except APIConnectionError:
-            raise CredentialsValidateFailedError('Invalid credentials {credentials}')
+        except Exception as ex:
+            raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         completion_model = None
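
The fix is twofold: the handler now catches any exception raised by the test invocation instead of only APIConnectionError, and the message gains the f-prefix it was missing, so the error detail is interpolated rather than emitting the literal text '{credentials}'. A minimal sketch of the resulting pattern, assuming a hypothetical probe_invoke in place of the real test call:

class CredentialsValidateFailedError(Exception):
    """Raised when provider credentials cannot be verified."""

def validate_credentials(probe_invoke, model: str, credentials: dict) -> None:
    # probe_invoke is a hypothetical stand-in for the provider's one-token test call.
    try:
        probe_invoke(model=model, credentials=credentials, max_tokens=10)
    except Exception as ex:
        # Any failure (auth error, connection error, bad model name) now surfaces
        # with its underlying message via the f-string.
        raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') from ex
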
(model configuration YAML — file name not shown)
@@ -10,12 +10,15 @@ model_properties:
 parameter_rules:
   - name: temperature
     use_template: temperature
+    min: 0.1
+    max: 1.0
+    default: 0.8
   - name: top_p
     use_template: top_p
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 4800
+    default: 256
     min: 1
     max: 4800
   - name: presence_penalty
(model configuration YAML — file name not shown)
@@ -10,12 +10,15 @@ model_properties:
 parameter_rules:
   - name: temperature
     use_template: temperature
+    min: 0.1
+    max: 1.0
+    default: 0.8
   - name: top_p
     use_template: top_p
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 8000
+    default: 1024
     min: 1
     max: 8000
   - name: presence_penalty
(model configuration YAML — file name not shown)
@@ -10,12 +10,15 @@ model_properties:
 parameter_rules:
   - name: temperature
     use_template: temperature
+    min: 0.1
+    max: 1.0
+    default: 0.8
   - name: top_p
     use_template: top_p
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 11200
+    default: 1024
     min: 1
     max: 11200
   - name: presence_penalty
(model configuration YAML — file name not shown)
@@ -10,12 +10,15 @@ model_properties:
 parameter_rules:
   - name: temperature
     use_template: temperature
+    min: 0.1
+    max: 1.0
+    default: 0.8
   - name: top_p
     use_template: top_p
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 4800
+    default: 256
     min: 1
     max: 4800
   - name: presence_penalty
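
The four model-YAML hunks above apply one pattern: temperature gains explicit bounds and a default (0.1–1.0, default 0.8), and the max_tokens default drops from the model's full ceiling (4800, 8000, or 11200 depending on the file) to a modest 256 or 1024, while the ceiling itself stays in max. A rough sketch of how such a rule could be applied to user-supplied parameters — apply_parameter_rule is a hypothetical helper, not the project's actual code:

def apply_parameter_rule(rule: dict, value=None):
    # Fall back to the rule's default when the caller supplied nothing,
    # then clamp into [min, max] when bounds are present.
    if value is None:
        value = rule.get('default')
    if value is None:
        return None
    if 'min' in rule:
        value = max(value, rule['min'])
    if 'max' in rule:
        value = min(value, rule['max'])
    return value

max_tokens_rule = {'default': 256, 'min': 1, 'max': 4800}
print(apply_parameter_rule(max_tokens_rule))        # 256 — not the full 4800 window
print(apply_parameter_rule(max_tokens_rule, 9999))  # 4800 — clamped to the ceiling
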
(Ernie streaming response handler, Python — file name not shown)
@@ -335,7 +335,7 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator
 
         if is_end:
             usage = data['usage']
-            finish_reason = data['finish_reason']
+            finish_reason = data.get('finish_reason', None)
             message = ErnieMessage(content=result, role='assistant')
             message.usage = {
                 'prompt_tokens': usage['prompt_tokens'],
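
The final streaming chunk apparently does not always carry a finish_reason field, so direct indexing could raise KeyError; dict.get falls back to None instead. A small illustration with a made-up chunk payload:

chunk = {'is_end': True, 'result': 'done',
         'usage': {'prompt_tokens': 3, 'completion_tokens': 5, 'total_tokens': 8}}

print(chunk.get('finish_reason', None))   # None — the new, tolerant lookup

try:
    print(chunk['finish_reason'])          # the old lookup
except KeyError:
    print('KeyError: final chunk carried no finish_reason')
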
13 changes: 7 additions & 6 deletions api/core/prompt/prompt_transform.py
@@ -9,7 +9,7 @@
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage, \
-    TextPromptMessageContent, PromptMessageRole
+    TextPromptMessageContent, PromptMessageRole, AssistantPromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_builder import PromptBuilder
@@ -244,7 +244,8 @@ def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict,
 
         prompt = re.sub(r'<\|.*?\|>', '', prompt)
 
-        prompt_messages.append(SystemPromptMessage(content=prompt))
+        if prompt:
+            prompt_messages.append(SystemPromptMessage(content=prompt))
 
         self._append_chat_histories(
             memory=memory,
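
Once the re.sub above strips the '<|...|>' template markers, the assembled pre-prompt can be empty, and the old code still appended an empty SystemPromptMessage; the new guard skips it. A stripped-down sketch of that behaviour, using plain dicts instead of the project's message classes:

import re

def build_system_messages(prompt: str) -> list[dict]:
    # Reduced version of the logic above: strip '<|...|>' template markers,
    # then only emit a system message when any text remains.
    prompt = re.sub(r'<\|.*?\|>', '', prompt)
    messages = []
    if prompt:
        messages.append({'role': 'system', 'content': prompt})
    return messages

print(build_system_messages('<|histories|>'))            # [] — no empty system message
print(build_system_messages('You are a travel agent.'))  # one system message
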
@@ -477,10 +478,10 @@ def _get_chat_app_chat_model_prompt_messages(self,
 
             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.SYSTEM:
+            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
                 prompt_messages.append(SystemPromptMessage(content=prompt))
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
-                prompt_messages.append(SystemPromptMessage(content=prompt))
+                prompt_messages.append(AssistantPromptMessage(content=prompt))
 
         self._append_chat_histories(memory, prompt_messages, model_config)
 
@@ -535,10 +536,10 @@ def _get_completion_app_chat_model_prompt_messages(self,
 
             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
-            elif prompt_item.role == PromptMessageRole.SYSTEM:
+            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
                 prompt_messages.append(SystemPromptMessage(content=prompt))
            elif prompt_item.role == PromptMessageRole.ASSISTANT:
-                prompt_messages.append(SystemPromptMessage(content=prompt))
+                prompt_messages.append(AssistantPromptMessage(content=prompt))
 
         for prompt_message in prompt_messages[::-1]:
             if prompt_message.role == PromptMessageRole.USER:
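
The two hunks above fix the same bug in the chat-app and completion-app paths: prompt items with the ASSISTANT role were wrapped in SystemPromptMessage, so earlier assistant turns reached the model as system messages; they now become AssistantPromptMessage (hence the new import at the top of the file), and SYSTEM items are only appended when the rendered prompt is non-empty. A stripped-down sketch of the dispatch, using stand-in types rather than the project's message entities:

from dataclasses import dataclass
from enum import Enum

class Role(Enum):
    USER = 'user'
    SYSTEM = 'system'
    ASSISTANT = 'assistant'

@dataclass
class Message:
    role: Role
    content: str

def to_messages(items: list[tuple[Role, str]]) -> list[Message]:
    messages = []
    for role, prompt in items:
        if role == Role.USER:
            messages.append(Message(Role.USER, prompt))
        elif role == Role.SYSTEM and prompt:                   # skip empty system prompts
            messages.append(Message(Role.SYSTEM, prompt))
        elif role == Role.ASSISTANT:
            messages.append(Message(Role.ASSISTANT, prompt))   # no longer forced to SYSTEM
    return messages

print(to_messages([(Role.SYSTEM, ''), (Role.ASSISTANT, 'Earlier answer.'), (Role.USER, 'Hi')]))
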
3 changes: 3 additions & 0 deletions api/services/model_provider_service.py
@@ -280,6 +280,9 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov
             if model.provider.provider not in provider_models:
                 provider_models[model.provider.provider] = []
 
+            if model.deprecated:
+                continue
+
             provider_models[model.provider.provider].append(model)
 
         # convert to ProviderWithModelsResponse list
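
With the new guard, a deprecated model still creates its provider bucket but is never appended to it, so it disappears from the model-type listing without dropping the provider entry. A compact illustration of that shape, using made-up model records:

from collections import defaultdict

models = [
    {'provider': 'openai', 'model': 'gpt-3.5-turbo-0301', 'deprecated': True},
    {'provider': 'openai', 'model': 'gpt-4', 'deprecated': False},
]

provider_models = defaultdict(list)
for model in models:
    # The provider bucket is created either way...
    provider_models[model['provider']]
    # ...but deprecated models are skipped before being appended.
    if model['deprecated']:
        continue
    provider_models[model['provider']].append(model['model'])

print(dict(provider_models))  # {'openai': ['gpt-4']}
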
