Commit e7a4a88

update.

GarfieldDai committed Oct 18, 2023
1 parent 12b7914 commit e7a4a88
Showing 3 changed files with 39 additions and 30 deletions.
9 changes: 6 additions & 3 deletions api/core/completion.py
@@ -167,7 +167,8 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
         else:
             prompt_messages = prompt_transform.get_advanced_prompt(
@@ -176,7 +177,8 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )

         model_config = app_model_config.model_dict
@@ -250,7 +252,8 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_
             inputs=inputs,
             query=query,
             context=None,
-            memory=None
+            memory=None,
+            model_instance=model_instance
         )

         prompt_tokens = model_instance.get_num_tokens(prompt_messages)
6 changes: 6 additions & 0 deletions api/core/model_providers/models/llm/base.py
@@ -310,6 +310,12 @@ def add_callbacks(self, callbacks: Callbacks):
     def support_streaming(self):
         return False

+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
         if not model_mode:
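Note: prompt_file_name now lives on BaseLLM with the generic common_* templates as the fallback, which lets a provider-specific model class point the prompt transform at its own rule file by overriding the method. A minimal sketch of such an override (the subclass and file names below are hypothetical, not part of this commit):

    class ExampleModel(BaseLLM):
        def prompt_file_name(self, mode: str) -> str:
            # Hypothetical override: route the prompt transform to
            # model-specific rule files instead of the common_* defaults.
            if mode == 'completion':
                return 'example_completion'
            else:
                return 'example_chat'

The prompt transform resolves the rule file through model_instance.prompt_file_name(mode), as shown in api/core/prompt/prompt_transform.py below.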
54 changes: 27 additions & 27 deletions api/core/prompt/prompt_transform.py
@@ -6,6 +6,7 @@
 from langchain.memory.chat_memory import BaseChatMemory
 from core.model_providers.models.entity.model_params import ModelMode, AppMode
 from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser

@@ -15,10 +16,11 @@ def get_prompt(self, mode: str,
                    pre_prompt: str, inputs: dict,
                    query: str,
                    context: Optional[str],
-                   memory: Optional[BaseChatMemory]) -> \
+                   memory: Optional[BaseChatMemory],
+                   model_instance: BaseLLM) -> \
             Tuple[List[PromptMessage], Optional[List[str]]]:
-        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
-        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        prompt_rules = self._read_prompt_rules_from_file(model_instance.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
         return [PromptMessage(content=prompt)], stops

     def get_advanced_prompt(self,
@@ -27,7 +29,8 @@ def get_advanced_prompt(self,
                             inputs: dict,
                             query: str,
                             context: Optional[str],
-                            memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                            memory: Optional[BaseChatMemory],
+                            model_instance: BaseLLM) -> List[PromptMessage]:

         model_mode = app_model_config.model_dict['mode']

@@ -38,9 +41,9 @@ def get_advanced_prompt(self,

         if app_mode_enum == AppMode.CHAT:
             if model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory)
+                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
             elif model_mode_enum == ModelMode.CHAT:
-                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory)
+                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
         elif app_mode_enum == AppMode.COMPLETION:
             if model_mode_enum == ModelMode.CHAT:
                 prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
@@ -67,12 +70,6 @@ def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
         memory.return_messages = False
         return to_prompt_messages(external_context[memory_key])

-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'common_completion'
-        else:
-            return 'common_chat'
-
     def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(
@@ -87,7 +84,8 @@ def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
     def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                              query: str,
                              context: Optional[str],
-                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+                             memory: Optional[BaseChatMemory],
+                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
         context_prompt_content = ''
         if context and 'context_prompt' in prompt_rules:
             prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -121,7 +119,7 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict
                 }
             )

-            rest_tokens = self._calculate_rest_token(tmp_human_message)
+            rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

             memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
             memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
@@ -166,15 +164,15 @@ def _set_query_variable(self, query, prompt_template, prompt_inputs):
         else:
             prompt_inputs['#query#'] = ''

-    def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs):
+    def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance):
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 tmp_human_message = PromptBuilder.to_human_message(
                     prompt_content=raw_prompt,
                     inputs={ '#histories#': '', **prompt_inputs }
                 )

-                rest_tokens = self._calculate_rest_token(tmp_human_message)
+                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

                 memory.human_prefix = conversation_histories_role['user_prefix']
                 memory.ai_prefix = conversation_histories_role['assistant_prefix']
@@ -183,22 +181,22 @@ def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role
             else:
                 prompt_inputs['#histories#'] = ''

-    def _append_chat_histories(self, memory, prompt_messages):
+    def _append_chat_histories(self, memory, prompt_messages, model_instance):
         if memory:
-            rest_tokens = self._calculate_rest_token(prompt_messages)
+            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)

             memory.human_prefix = MessageType.USER.value
             memory.ai_prefix = MessageType.ASSISTANT.value
             histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
             prompt_messages.extend(histories)

-    def _calculate_rest_token(self, prompt_messages):
+    def _calculate_rest_token(self, prompt_messages, model_instance: BaseLLM):
         rest_tokens = 2000

-        if self.model_rules.max_tokens.max:
-            curr_message_tokens = self.get_num_tokens(to_prompt_messages(prompt_messages))
-            max_tokens = self.model_kwargs.max_tokens
-            rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+        if model_instance.model_rules.max_tokens.max:
+            curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))
+            max_tokens = model_instance.model_kwargs.max_tokens
+            rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)

         return rest_tokens
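Note: _calculate_rest_token now reads the limits from the model_instance that will actually run the prompt, so the history budget is rest_tokens = model_rules.max_tokens.max - max_tokens - curr_message_tokens, clamped at 0, with a 2000-token default when the model declares no maximum. As a worked example with hypothetical numbers (not from this commit): a 4096-token context limit, max_tokens of 512 and a current prompt measuring 1200 tokens leaves 4096 - 512 - 1200 = 2384 tokens for conversation history.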
@@ -216,7 +214,8 @@ def _get_chat_app_completion_model_prompt_messages(self,
                                 inputs: dict,
                                 query: str,
                                 context: Optional[str],
-                                memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                                memory: Optional[BaseChatMemory],
+                                model_instance: BaseLLM) -> List[PromptMessage]:

         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
@@ -231,7 +230,7 @@ def _get_chat_app_completion_model_prompt_messages(self,

         self._set_query_variable(query, prompt_template, prompt_inputs)

-        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs)
+        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)

         prompt = self._format_prompt(prompt_template, prompt_inputs)

@@ -244,7 +243,8 @@ def _get_chat_app_chat_model_prompt_messages(self,
                              inputs: dict,
                              query: str,
                              context: Optional[str],
-                             memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+                             memory: Optional[BaseChatMemory],
+                             model_instance: BaseLLM) -> List[PromptMessage]:
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

         prompt_messages = []
@@ -262,7 +262,7 @@ def _get_chat_app_chat_model_prompt_messages(self,

             prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

-        self._append_chat_histories(memory, prompt_messages)
+        self._append_chat_histories(memory, prompt_messages, model_instance)

         prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))

