From 714aa0219d358aec4749a8c870e79101fde3cac5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 14 Oct 2024 12:15:56 +0800 Subject: [PATCH] refactor(prompt): improve handling of variable templates in advanced prompt transform --- api/core/prompt/advanced_prompt_transform.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f5cea3cecb72d..49818b69d24192 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -145,11 +145,20 @@ def _get_chat_model_prompt_messages( raw_prompt = prompt_item.text if prompt_item.edition_type == "basic" or not prompt_item.edition_type: - vp = VariablePool() - for k, v in inputs.items(): - vp.add(k[1:-1].split("."), v) - raw_prompt.replace("{{#context#}}", context or "") - prompt = vp.convert_template(raw_prompt).text + if self.with_variable_tmpl: + vp = VariablePool() + for k, v in inputs.items(): + if k.startswith("#"): + vp.add(k[1:-1].split("."), v) + raw_prompt.replace("{{#context#}}", context or "") + prompt = vp.convert_template(raw_prompt).text + else: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs = self._set_context_variable( + context=context, parser=parser, prompt_inputs=prompt_inputs + ) + prompt = parser.format(prompt_inputs) elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs