Skip to content

Commit

Permalink
refactor advanced prompt core.
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Oct 15, 2023
1 parent 2feb16d commit 96598e5
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 84 deletions.
4 changes: 4 additions & 0 deletions api/core/model_providers/models/entity/model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from langchain.load.serializable import Serializable
from pydantic import BaseModel

class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'


class ModelMode(enum.Enum):
COMPLETION = 'completion'
Expand Down
242 changes: 172 additions & 70 deletions api/core/model_providers/models/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules, AppMode
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
Expand Down Expand Up @@ -330,83 +330,191 @@ def get_prompt(self, mode: str,
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops

def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}

raw_prompt_list = []
app_mode_enum = AppMode(app_mode)
model_mode_enum = ModelMode(model_mode)

prompt_messages = []

if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)

return prompt_messages

def _set_context_variable(self, context, prompt_template, prompt_inputs):
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

def _set_query_variable(self, query, prompt_template, prompt_inputs):
if '#query#' in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs):
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
)

rest_tokens = self._calculate_rest_token(tmp_human_message)

memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''

def _append_chat_histories(self, memory, prompt_messages):
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages)

memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)

def _calculate_rest_token(self, prompt_messages):
rest_tokens = 2000

if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages(prompt_messages))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)

return rest_tokens

def _format_prompt(self, prompt_template, prompt_inputs):
prompt = prompt_template.format(
prompt_inputs
)

prompt = re.sub(r'<\|.*?\|>', '', prompt)
return prompt

def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

self._set_query_variable(query, prompt_template, prompt_inputs)

self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))

return prompt_messages

def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
raw_prompt = prompt_item['text']
prompt = ''

# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''

prompt = prompt_template.format(
prompt_inputs
)
self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

self._append_chat_histories(memory, prompt_messages)

if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))

return prompt_messages

def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))

return prompt_messages

def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

return prompt_messages

def prompt_file_name(self, mode: str) -> str:
Expand Down Expand Up @@ -452,13 +560,7 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict
}
)

if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
rest_tokens = self._calculate_rest_token(tmp_human_message)

memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
Expand Down
25 changes: 13 additions & 12 deletions api/services/advanced_prompt_template_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import copy

from core.model_providers.models.entity.model_params import AppMode, ModelMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT

Expand All @@ -22,15 +23,15 @@ def get_prompt(cls, args: dict) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
context_prompt = copy.deepcopy(CONTEXT)

if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)

@classmethod
Expand All @@ -51,13 +52,13 @@ def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str)
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
11 changes: 9 additions & 2 deletions api/services/app_model_config_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from core.model_providers.models.entity.model_params import ModelType, ModelMode, AppMode
from models.account import Account
from services.dataset_service import DatasetService

Expand Down Expand Up @@ -418,7 +418,7 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")

if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']

Expand All @@ -427,3 +427,10 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:

if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'


if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']

if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

0 comments on commit 96598e5

Please sign in to comment.