Commit

refactor: move prompt construction out of BaseLLM into a dedicated PromptTransform class.
GarfieldDai committed Oct 17, 2023
1 parent 96598e5 commit 12b7914
Showing 3 changed files with 326 additions and 319 deletions.
11 changes: 8 additions & 3 deletions api/core/completion.py
@@ -16,6 +16,7 @@
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser


@@ -156,9 +157,11 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()

# get llm prompt
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
prompt_messages, stop_words = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
@@ -167,7 +170,7 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
memory=memory
)
else:
prompt_messages = model_instance.get_advanced_prompt(
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
@@ -238,8 +241,10 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_
if max_tokens is None:
max_tokens = 0

prompt_transform = PromptTransform()

# get prompt without memory and context
prompt_messages, _ = model_instance.get_prompt(
prompt_messages, _ = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
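
The third file in this commit, presumably the new api/core/prompt/prompt_transform.py that completion.py now imports, is not rendered in this view (its diff fails to load below). Based only on the call sites above and on the BaseLLM methods removed in the next file, a minimal sketch of the interface it likely exposes follows; the method bodies and any parameters hidden behind the collapsed diff context (for example a model instance used for token counting) are assumptions, not the committed code.

# Hypothetical sketch of the new PromptTransform interface -- not the committed file.
from typing import List, Optional, Tuple

from langchain.memory.chat_memory import BaseChatMemory

from core.model_providers.models.entity.message import PromptMessage


class PromptTransform:
    def get_prompt(self, mode: str, pre_prompt: str, inputs: dict, query: str,
                   context: Optional[str], memory: Optional[BaseChatMemory],
                   **kwargs) -> Tuple[List[PromptMessage], Optional[List[str]]]:
        # Build the "simple" prompt from the per-mode rules file and return the
        # prompt messages plus optional stop words (mirrors the removed BaseLLM.get_prompt).
        raise NotImplementedError

    def get_advanced_prompt(self, app_mode: str, app_model_config, inputs: dict,
                            query: str, context: Optional[str],
                            memory: Optional[BaseChatMemory],
                            **kwargs) -> List[PromptMessage]:
        # Dispatch on app mode and model mode (mirrors the removed BaseLLM.get_advanced_prompt).
        raise NotImplementedError

As the completion.py hunks above show, callers now construct a PromptTransform per request and call these methods where they previously called model_instance.get_prompt and model_instance.get_advanced_prompt.
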
322 changes: 6 additions & 316 deletions api/core/model_providers/models/llm/base.py
@@ -1,28 +1,18 @@
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
from typing import List, Optional, Any, Union
import decimal
import logging

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from langchain.schema import LLMResult, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules, AppMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging

from extensions.ext_database import db

logger = logging.getLogger(__name__)

@@ -320,308 +310,8 @@ def add_callbacks(self, callbacks: Callbacks):
def support_streaming(self):
return False

def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops

def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

model_mode = app_model_config.model_dict['mode']

app_mode_enum = AppMode(app_mode)
model_mode_enum = ModelMode(model_mode)

prompt_messages = []
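# Four combinations are handled below:
#   chat app       + completion-mode model -> _get_chat_app_completion_model_prompt_messages
#   chat app       + chat-mode model       -> _get_chat_app_chat_model_prompt_messages
#   completion app + chat-mode model       -> _get_completion_app_chat_model_prompt_messages
#   completion app + completion-mode model -> _get_completion_app_completion_model_prompt_messages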

if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)

return prompt_messages

def _set_context_variable(self, context, prompt_template, prompt_inputs):
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

def _set_query_variable(self, query, prompt_template, prompt_inputs):
if '#query#' in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs):
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
)

rest_tokens = self._calculate_rest_token(tmp_human_message)

memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''

def _append_chat_histories(self, memory, prompt_messages):
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages)

memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)

def _calculate_rest_token(self, prompt_messages):
rest_tokens = 2000

if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages(prompt_messages))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
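# e.g. a 4096-token context window (model_rules.max_tokens.max), a 512-token completion
# budget (model_kwargs.max_tokens) and a prompt currently measuring 1,200 tokens leave
# 4096 - 512 - 1200 = 2384 tokens for histories; negative results clamp to 0, and models
# without a max_tokens rule keep the 2000-token default above.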

return rest_tokens

def _format_prompt(self, prompt_template, prompt_inputs):
prompt = prompt_template.format(
prompt_inputs
)
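# The re.sub below strips anything wrapped in '<|' ... '|>' (model special tokens
# such as '<|endoftext|>') from the rendered prompt.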

prompt = re.sub(r'<\|.*?\|>', '', prompt)
return prompt

def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

self._set_query_variable(query, prompt_template, prompt_inputs)

self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))

return prompt_messages

def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

self._append_chat_histories(memory, prompt_messages)

prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))

return prompt_messages

def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))

return prompt_messages

def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

return prompt_messages

def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
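# i.e. the rules file is prompt/generate_prompts/common_completion.json or
# common_chat.json, read by _read_prompt_rules_from_file below.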

def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
{'context': context}
)

pre_prompt_content = ''
if pre_prompt:
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
prompt_inputs
)

prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content

query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)

rest_tokens = self._calculate_rest_token(tmp_human_message)

memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})

prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content

prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})

prompt += query_prompt_content

prompt = re.sub(r'<\|.*?\|>', '', prompt)

stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None

return prompt, stops
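# The rules files read below are not part of this diff; judging from the lookups above,
# common_chat.json / common_completion.json presumably look something like:
# {
#   "context_prompt": "...{{context}}...",
#   "histories_prompt": "...{{histories}}...",
#   "query_prompt": "{{query}}",
#   "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
#   "human_prefix": "Human",
#   "assistant_prefix": "Assistant",
#   "stops": ["\nHuman:"]
# }
# The key names come from the code above; the template text and example values are
# illustrative assumptions.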

def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
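# Four dirname() calls climb from .../model_providers/models/llm/base.py up to api/core,
# so the rules are read from api/core/prompt/generate_prompts/<prompt_name>.json.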

json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)

def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]

def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])

def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:
model_mode = self.model_mode

[The diff for the third changed file, presumably the new api/core/prompt/prompt_transform.py, did not load in this view.]
