From 695841a3cf865840bdd5699ff7c406bc481919a2 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 13 Oct 2023 16:47:01 +0800 Subject: [PATCH] Feat/advanced prompt enhancement (#1340) --- .../console/app/advanced_prompt_template.py | 3 +- api/core/prompt/advanced_prompt_templates.py | 12 ++++-- .../advanced_prompt_template_service.py | 39 +++++++++++-------- api/services/app_model_config_service.py | 3 ++ 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index ce47e9e4d8793b..c92f0570dc4f94 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -20,7 +20,6 @@ def get(self): parser.add_argument('model_name', type=str, required=True, location='args') args = parser.parse_args() - service = AdvancedPromptTemplateService() - return service.get_prompt(args) + return AdvancedPromptTemplateService.get_prompt(args) api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/advanced_prompt_templates.py index c5eee005b6faab..da40534d99485b 100644 --- a/api/core/prompt/advanced_prompt_templates.py +++ b/api/core/prompt/advanced_prompt_templates.py @@ -11,7 +11,8 @@ "user_prefix": "Human", "assistant_prefix": "Assistant" } - } + }, + "stop": ["Human:"] } CHAT_APP_CHAT_PROMPT_CONFIG = { @@ -37,7 +38,8 @@ "prompt": { "text": "{{#pre_prompt#}}" } - } + }, + "stop": ["Human:"] } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { @@ -49,7 +51,8 @@ "user_prefix": "用户", "assistant_prefix": "助手" } - } + }, + "stop": ["用户:"] } BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { @@ -75,5 +78,6 @@ "prompt": { "text": "{{#pre_prompt#}}" } - } + }, + "stop": ["用户:"] } diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 
3ef2b6059e5ff4..58e2c658fb9c2e 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -6,51 +6,58 @@ class AdvancedPromptTemplateService: - def get_prompt(self, args: dict) -> dict: + @classmethod + def get_prompt(cls, args: dict) -> dict: app_mode = args['app_mode'] model_mode = args['model_mode'] model_name = args['model_name'] has_context = args['has_context'] if 'baichuan' in model_name: - return self.get_baichuan_prompt(app_mode, model_mode, has_context) + return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: - return self.get_common_prompt(app_mode, model_mode, has_context) + return cls.get_common_prompt(app_mode, model_mode, has_context) + + @classmethod + def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + context_prompt = copy.deepcopy(CONTEXT) - def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: if app_mode == 'chat': if model_mode == 'completion': - return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) + return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) elif model_mode == 'chat': - return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) + return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == 'completion': if model_mode == 'completion': - return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) + return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) elif model_mode == 'chat': - return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) + return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, 
context_prompt) - def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: + @classmethod + def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: if has_context == 'true': prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] return prompt_template - - def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: + @classmethod + def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: if has_context == 'true': prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] return prompt_template + @classmethod + def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: if app_mode == 'chat': if model_mode == 'completion': - return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) elif model_mode == 'chat': - return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) elif app_mode == 'completion': if model_mode == 'completion': - return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) elif model_mode == 'chat': - return 
self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) \ No newline at end of file + return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 4acb2f346fbfc2..4c49b43b887c6d 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -56,6 +56,9 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") + + if len(cp["stop"]) > 4: + raise ValueError("stop sequences must not exceed 4") # Filter out extra parameters filtered_cp = {