refactor: move prompt construction out of the model classes into a shared PromptTransform.
GarfieldDai committed Oct 18, 2023
1 parent 96598e5 commit 5832488
Showing 10 changed files with 387 additions and 368 deletions.
44 changes: 32 additions & 12 deletions api/core/completion.py
@@ -16,6 +16,7 @@
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompt_transform import PromptTransform
 from models.model import App, AppModelConfig, Account, Conversation, EndUser


@@ -156,24 +157,28 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                       fake_response: Optional[str]):
+        prompt_transform = PromptTransform()
+
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
-            prompt_messages, stop_words = model_instance.get_prompt(
+            prompt_messages, stop_words = prompt_transform.get_prompt(
                 mode=mode,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
         else:
-            prompt_messages = model_instance.get_advanced_prompt(
+            prompt_messages = prompt_transform.get_advanced_prompt(
                 app_mode=mode,
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )

         model_config = app_model_config.model_dict
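
For orientation, here is a minimal, self-contained sketch of the shape this hunk gives the call site: prompt assembly lives in one PromptTransform object that receives the model instance as a parameter, instead of each BaseLLM subclass building its own prompt. The class body below is an illustrative assumption, not the actual implementation in core/prompt/prompt_transform.py.

```python
# Illustrative sketch only; the real PromptTransform is more involved.
# It shows the new calling convention: one transformer, model passed in.
from typing import List, Optional, Tuple


class PromptTransform:
    def get_prompt(self, mode: str, pre_prompt: Optional[str], inputs: dict,
                   query: str, context: Optional[str], memory,
                   model_instance) -> Tuple[List[str], List[str]]:
        # 'simple' mode: stitch the pre-prompt, retrieved context and the
        # user query together; stop words would come from a per-model template.
        parts = [p for p in (pre_prompt, context, query) if p]
        return ["\n".join(parts)], []

    def get_advanced_prompt(self, app_mode: str, app_model_config, inputs: dict,
                            query: str, context: Optional[str], memory,
                            model_instance) -> List[str]:
        # 'advanced' mode: the app config carries a user-authored template.
        return [f"[{app_mode}] {query}"]


# Same call shape as the refactored run_final_llm above:
messages, stop_words = PromptTransform().get_prompt(
    mode="completion", pre_prompt="You are a helpful assistant.", inputs={},
    query="Hello", context=None, memory=None, model_instance=None,
)
print(messages)  # ['You are a helpful assistant.\nHello']
```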
@@ -238,15 +243,30 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_
         if max_tokens is None:
             max_tokens = 0

+        prompt_transform = PromptTransform()
+        prompt_messages = []
+
         # get prompt without memory and context
-        prompt_messages, _ = model_instance.get_prompt(
-            mode=mode,
-            pre_prompt=app_model_config.pre_prompt,
-            inputs=inputs,
-            query=query,
-            context=None,
-            memory=None
-        )
+        if app_model_config.prompt_type == 'simple':
+            prompt_messages, _ = prompt_transform.get_prompt(
+                mode=mode,
+                pre_prompt=app_model_config.pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )
+        else:
+            prompt_messages = prompt_transform.get_advanced_prompt(
+                app_mode=mode,
+                app_model_config=app_model_config,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )

         prompt_tokens = model_instance.get_num_tokens(prompt_messages)
         rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
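
The budget arithmetic at the end of this hunk is unchanged by the refactor: the tokens left over for memory and retrieved context are the model's context limit, minus the tokens reserved for the completion (max_tokens), minus the tokens already consumed by the bare prompt. A worked example with illustrative numbers, not taken from any real model config:

```python
# Worked example of the rest-token budget computed above.
model_limited_tokens = 4096   # model's total context window (illustrative)
max_tokens = 512              # tokens reserved for the completion
prompt_tokens = 1200          # tokens used by the assembled prompt

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
print(rest_tokens)  # 2384 tokens left for memory and retrieved context
```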
4 changes: 0 additions & 4 deletions api/core/model_providers/models/entity/model_params.py
@@ -4,10 +4,6 @@
 from langchain.load.serializable import Serializable
 from pydantic import BaseModel

-class AppMode(enum.Enum):
-    COMPLETION = 'completion'
-    CHAT = 'chat'
-

 class ModelMode(enum.Enum):
     COMPLETION = 'completion'
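
With the AppMode enum deleted, the app mode travels through completion.py as a plain string (mode=mode, app_mode=mode) compared against literals such as 'simple' and 'completion'; only the model-level ModelMode enum remains. A small sketch of that string-based handling — the helper is hypothetical, and it assumes ModelMode also defines CHAT, which the diff above cuts off:

```python
import enum


class ModelMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'  # assumed; the diff above truncates the enum body


def pick_model_mode(app_mode: str) -> ModelMode:
    # Hypothetical helper: the app mode now arrives as a plain string
    # rather than an AppMode member.
    return ModelMode.CHAT if app_mode == 'chat' else ModelMode.COMPLETION


print(pick_model_mode('chat'))        # ModelMode.CHAT
print(pick_model_mode('completion'))  # ModelMode.COMPLETION
```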
6 changes: 0 additions & 6 deletions api/core/model_providers/models/llm/baichuan_model.py
@@ -37,12 +37,6 @@ def _run(self, messages: List[PromptMessage],
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'baichuan_completion'
-        else:
-            return 'baichuan_chat'
-
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
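
Deleting Baichuan's prompt_file_name override fits the same pattern: template selection can live in one place once PromptTransform knows which model it is serving. A hedged sketch of such a centralized lookup — the function name and mapping are assumptions, not the actual Dify code:

```python
# Hedged sketch of a centralized prompt-template lookup that makes
# per-model prompt_file_name overrides like Baichuan's unnecessary.
def prompt_file_name(mode: str, model_name: str = '') -> str:
    # Route Baichuan models to their dedicated templates, everything
    # else to a common template, based on the mode string.
    if 'baichuan' in model_name.lower():
        return 'baichuan_completion' if mode == 'completion' else 'baichuan_chat'
    return 'common_completion' if mode == 'completion' else 'common_chat'


print(prompt_file_name('chat', 'baichuan-13b'))  # baichuan_chat
print(prompt_file_name('completion'))            # common_completion
```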
(diffs for the remaining 7 changed files not shown)
