diff --git a/api/core/completion.py b/api/core/completion.py
index a481ec39d30413..f9078617b6a64e 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -167,7 +167,8 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
         else:
             prompt_messages = prompt_transform.get_advanced_prompt(
@@ -176,7 +177,8 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
 
         model_config = app_model_config.model_dict
@@ -250,7 +252,8 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_
             inputs=inputs,
             query=query,
             context=None,
-            memory=None
+            memory=None,
+            model_instance=model_instance
         )
 
         prompt_tokens = model_instance.get_num_tokens(prompt_messages)
diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index 41724dd54bfae0..23f35fadb14332 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -310,6 +310,12 @@ def add_callbacks(self, callbacks: Callbacks):
     def support_streaming(self):
         return False
 
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
         if not model_mode:
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 9cd1c7704d4cd2..c3024f51565049 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -6,6 +6,7 @@
 from langchain.memory.chat_memory import BaseChatMemory
 from core.model_providers.models.entity.model_params import ModelMode, AppMode
 from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
 
@@ -15,10 +16,11 @@ def get_prompt(self, mode: str,
                    pre_prompt: str, inputs: dict,
                    query: str,
                    context: Optional[str],
-                   memory: Optional[BaseChatMemory]) -> \
+                   memory: Optional[BaseChatMemory],
+                   model_instance: BaseLLM) -> \
             Tuple[List[PromptMessage], Optional[List[str]]]:
-        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
-        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        prompt_rules = self._read_prompt_rules_from_file(model_instance.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
         return [PromptMessage(content=prompt)], stops
 
     def get_advanced_prompt(self,
@@ -27,7 +29,8 @@
                             inputs: dict,
                             query: str,
                             context: Optional[str],
-                            memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                            memory: Optional[BaseChatMemory],
+                            model_instance: BaseLLM) -> List[PromptMessage]:
 
         model_mode = app_model_config.model_dict['mode']
 
@@ -38,9 +41,9 @@ def get_advanced_prompt(self,
 
         if app_mode_enum == AppMode.CHAT:
             if model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory)
+                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
             elif model_mode_enum == ModelMode.CHAT:
-                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory)
+                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
         elif app_mode_enum == AppMode.COMPLETION:
             if model_mode_enum == ModelMode.CHAT:
                 prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
@@ -67,12 +70,6 @@ def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
         memory.return_messages = False
         return to_prompt_messages(external_context[memory_key])
 
-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'common_completion'
-        else:
-            return 'common_chat'
-
     def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(
@@ -87,7 +84,8 @@
     def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                              query: str,
                              context: Optional[str],
-                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+                             memory: Optional[BaseChatMemory],
+                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
         context_prompt_content = ''
         if context and 'context_prompt' in prompt_rules:
             prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -121,7 +119,7 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict
                 }
             )
 
-            rest_tokens = self._calculate_rest_token(tmp_human_message)
+            rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
 
             memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
             memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
@@ -166,7 +164,7 @@ def _set_query_variable(self, query, prompt_template, prompt_inputs):
         else:
             prompt_inputs['#query#'] = ''
 
-    def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs):
+    def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance):
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 tmp_human_message = PromptBuilder.to_human_message(
@@ -174,7 +172,7 @@ def _set_histories_variable(self, memory, raw_prompt, conversation_histories_rol
                     inputs={ '#histories#': '', **prompt_inputs }
                 )
 
-                rest_tokens = self._calculate_rest_token(tmp_human_message)
+                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
 
                 memory.human_prefix = conversation_histories_role['user_prefix']
                 memory.ai_prefix = conversation_histories_role['assistant_prefix']
@@ -183,22 +181,22 @@ def _set_histories_variable(self, memory, raw_prompt, conversation_histories_rol
         else:
             prompt_inputs['#histories#'] = ''
 
-    def _append_chat_histories(self, memory, prompt_messages):
+    def _append_chat_histories(self, memory, prompt_messages, model_instance):
         if memory:
-            rest_tokens = self._calculate_rest_token(prompt_messages)
+            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
             memory.human_prefix = MessageType.USER.value
             memory.ai_prefix = MessageType.ASSISTANT.value
 
             histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
             prompt_messages.extend(histories)
 
-    def _calculate_rest_token(self, prompt_messages):
+    def _calculate_rest_token(self, prompt_messages, model_instance: BaseLLM):
         rest_tokens = 2000
 
-        if self.model_rules.max_tokens.max:
-            curr_message_tokens = self.get_num_tokens(to_prompt_messages(prompt_messages))
-            max_tokens = self.model_kwargs.max_tokens
-            rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+        if model_instance.model_rules.max_tokens.max:
+            curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))
+            max_tokens = model_instance.model_kwargs.max_tokens
+            rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
 
         return rest_tokens
 
@@ -216,7 +214,8 @@ def _get_chat_app_completion_model_prompt_messages(self,
                                                        inputs: dict,
                                                        query: str,
                                                        context: Optional[str],
-                                                       memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                                                       memory: Optional[BaseChatMemory],
+                                                       model_instance: BaseLLM) -> List[PromptMessage]:
 
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
@@ -231,7 +230,7 @@ def _get_chat_app_completion_model_prompt_messages(self,
 
         self._set_query_variable(query, prompt_template, prompt_inputs)
 
-        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs)
+        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
@@ -244,7 +243,8 @@ def _get_chat_app_chat_model_prompt_messages(self,
                                                  inputs: dict,
                                                  query: str,
                                                  context: Optional[str],
-                                                 memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                                                 memory: Optional[BaseChatMemory],
+                                                 model_instance: BaseLLM) -> List[PromptMessage]:
 
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
         prompt_messages = []
@@ -262,7 +262,7 @@ def _get_chat_app_chat_model_prompt_messages(self,
             prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
 
-        self._append_chat_histories(memory, prompt_messages)
+        self._append_chat_histories(memory, prompt_messages, model_instance)
 
         prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
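
For context, the core behavioral change above is that _calculate_rest_token no longer reads token limits from self (the PromptTransform) but from the BaseLLM instance that is now threaded through every call. The following is a minimal, self-contained sketch of that budget calculation, remaining = context limit - reserved completion tokens - current prompt tokens, using hypothetical stand-in classes (FakeModelInstance, MaxTokensRule, ModelKwargs are illustrative only, not the project's real BaseLLM or its rules objects):

from dataclasses import dataclass
from types import SimpleNamespace
from typing import List


@dataclass
class MaxTokensRule:
    """Stand-in for model_rules.max_tokens; `max` is the model's context limit."""
    max: int


@dataclass
class ModelKwargs:
    """Stand-in for model_kwargs; `max_tokens` is the reserved completion budget."""
    max_tokens: int


class FakeModelInstance:
    """Hypothetical stand-in for a BaseLLM subclass passed in as model_instance."""

    def __init__(self, context_limit: int, completion_budget: int):
        self.model_rules = SimpleNamespace(max_tokens=MaxTokensRule(context_limit))
        self.model_kwargs = ModelKwargs(completion_budget)

    def get_num_tokens(self, messages: List[str]) -> int:
        # Crude whitespace count; the real class delegates to the provider's tokenizer.
        return sum(len(m.split()) for m in messages)


def calculate_rest_tokens(prompt_messages: List[str], model_instance: FakeModelInstance) -> int:
    rest_tokens = 2000  # fallback when the model declares no hard context limit
    if model_instance.model_rules.max_tokens.max:
        curr_message_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = (model_instance.model_rules.max_tokens.max
                       - model_instance.model_kwargs.max_tokens
                       - curr_message_tokens)
        rest_tokens = max(rest_tokens, 0)
    return rest_tokens


if __name__ == "__main__":
    model = FakeModelInstance(context_limit=4096, completion_budget=512)
    print(calculate_rest_tokens(["You are a helpful assistant.", "Hello there!"], model))

The resulting rest_tokens value is what the transform hands to _get_history_messages_list_from_memory, so conversation history is trimmed to whatever room is left after the prompt and the reserved completion tokens.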