From 4f5c052dc823b6d5c1e389b1c7b1db6f47f2327a Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 19:15:11 +0800 Subject: [PATCH 1/6] fix single step run error --- api/services/workflow_service.py | 64 +++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2c9c07106cec94..55f2526fbfc827 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -270,28 +270,48 @@ def run_draft_workflow_node(self, app_model: App, return workflow_node_execution - # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() - ) + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, + process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, + outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, + execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) + if node_run_result.metadata else None), + status=WorkflowNodeExecutionStatus.SUCCEEDED.value, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + else: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + status=node_run_result.status.value, + error=node_run_result.error, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) db.session.add(workflow_node_execution) db.session.commit() From 3f59a579d78b1a92ca95d1031a6ba54f172c5261 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 22:12:03 +0800 Subject: [PATCH 2/6] add llm node --- api/core/app/apps/base_app_runner.py | 31 +- 
.../easy_ui_based_generate_task_pipeline.py | 83 +--- api/core/model_manager.py | 4 +- api/core/prompt/advanced_prompt_transform.py | 51 ++- .../entities}/__init__.py | 0 .../entities/advanced_prompt_entities.py | 42 ++ api/core/prompt/prompt_transform.py | 19 +- api/core/prompt/simple_prompt_transform.py | 11 + api/core/prompt/utils/prompt_message_util.py | 85 ++++ api/core/workflow/entities/node_entities.py | 2 +- api/core/workflow/nodes/answer/__init__.py | 0 .../answer_node.py} | 8 +- .../{direct_answer => answer}/entities.py | 4 +- api/core/workflow/nodes/llm/entities.py | 45 ++- api/core/workflow/nodes/llm/llm_node.py | 370 +++++++++++++++++- api/core/workflow/workflow_engine_manager.py | 47 +-- .../prompt/test_advanced_prompt_transform.py | 77 ++-- 17 files changed, 697 insertions(+), 182 deletions(-) rename api/core/{workflow/nodes/direct_answer => prompt/entities}/__init__.py (100%) create mode 100644 api/core/prompt/entities/advanced_prompt_entities.py create mode 100644 api/core/prompt/utils/prompt_message_util.py create mode 100644 api/core/workflow/nodes/answer/__init__.py rename api/core/workflow/nodes/{direct_answer/direct_answer_node.py => answer/answer_node.py} (91%) rename api/core/workflow/nodes/{direct_answer => answer}/entities.py (75%) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e7ce7f25ef51a4..868e9e724f4081 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -23,7 +23,8 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation @@ -155,13 +156,39 @@ def organize_prompt_messages(self, app_record: App, model_config=model_config ) else: + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + prompt_template = CompletionModelPromptTemplate( + text=advanced_completion_prompt_template.prompt + ) + + memory_config.role_prefix = MemoryConfig.RolePrefix( + user=advanced_completion_prompt_template.role_prefix.user, + assistant=advanced_completion_prompt_template.role_prefix.assistant + ) + else: + prompt_template = [] + for message in prompt_template_entity.advanced_chat_prompt_template.messages: + prompt_template.append(ChatModelMessage( + text=message.text, + role=message.role + )) + prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query if query else '', files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 856bfb623d0e22..412029b02491f8 100644 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ 
b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -30,17 +30,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created @@ -438,7 +433,10 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message = db.session.query(Message).filter(Message.id == self._message.id).first() self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() - self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) + self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + self._model_config.mode, + self._task_state.llm_result.prompt_messages + ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -582,77 +580,6 @@ def _yield_response(self, response: dict) -> str: """ return "data: " + json.dumps(response) + "\n\n" - def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: - """ - Prompt messages to prompt for saving. - :param prompt_messages: prompt messages - :return: - """ - prompts = [] - if self._model_config.mode == ModelMode.CHAT.value: - for prompt_message in prompt_messages: - if prompt_message.role == PromptMessageRole.USER: - role = 'user' - elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' - elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' - else: - continue - - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - prompts.append({ - "role": role, - "text": text, - "files": files - }) - else: - prompt_message = prompt_messages[0] - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' 
+ content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - params = { - "role": 'user', - "text": text, - } - - if files: - params['files'] = files - - prompts.append(params) - - return prompts - def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. diff --git a/api/core/model_manager.py b/api/core/model_manager.py index aa16cf866f9327..8c0633992767dc 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -24,11 +24,11 @@ class ModelInstance: """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: - self._provider_model_bundle = provider_model_bundle + self.provider_model_bundle = provider_model_bundle self.model = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) - self.model_type_instance = self._provider_model_bundle.model_type_instance + self.model_type_instance = self.provider_model_bundle.model_type_instance def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 48b0d8ba021e03..60c77e943b3322 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,5 @@ -from typing import Optional +from typing import Optional, Union -from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -12,6 +11,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. 
""" - def get_prompt(self, prompt_template_entity: PromptTemplateEntity, + def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], inputs: dict, query: str, files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: prompt_messages = [] @@ -34,21 +35,23 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: prompt_messages = self._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) @@ -56,17 +59,18 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, return prompt_messages def _get_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get completion model prompt messages. """ - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt + raw_prompt = prompt_template.text prompt_messages = [] @@ -75,15 +79,17 @@ def _get_completion_model_prompt_messages(self, prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - prompt_inputs = self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config - ) + if memory and memory_config: + role_prefix = memory_config.role_prefix + prompt_inputs = self._set_histories_variable( + memory=memory, + memory_config=memory_config, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) @@ -104,17 +110,18 @@ def _get_completion_model_prompt_messages(self, return prompt_messages def _get_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: list[ChatModelMessage], inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get chat model prompt messages. 
""" - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages + raw_prompt_list = prompt_template prompt_messages = [] @@ -137,8 +144,8 @@ def _get_chat_model_prompt_messages(self, elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if memory: - prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config) + if memory and memory_config: + prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) if files: prompt_message_contents = [TextPromptMessageContent(data=query)] @@ -195,8 +202,9 @@ def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, return prompt_inputs def _set_histories_variable(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, + role_prefix: MemoryConfig.RolePrefix, prompt_template: PromptTemplateParser, prompt_inputs: dict, model_config: ModelConfigWithCredentialsEntity) -> dict: @@ -213,6 +221,7 @@ def _set_histories_variable(self, memory: TokenBufferMemory, histories = self._get_history_messages_from_memory( memory=memory, + memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, ai_prefix=role_prefix.assistant diff --git a/api/core/workflow/nodes/direct_answer/__init__.py b/api/core/prompt/entities/__init__.py similarity index 100% rename from api/core/workflow/nodes/direct_answer/__init__.py rename to api/core/prompt/entities/__init__.py diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py new file mode 100644 index 00000000000000..97ac2e3e2a8651 --- /dev/null +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -0,0 +1,42 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """ + Chat Message. + """ + text: str + role: PromptMessageRole + + +class CompletionModelPromptTemplate(BaseModel): + """ + Completion Model Prompt Template. + """ + text: str + + +class MemoryConfig(BaseModel): + """ + Memory Config. + """ + class RolePrefix(BaseModel): + """ + Role Prefix. + """ + user: str + assistant: str + + class WindowConfig(BaseModel): + """ + Window Config. 
+ """ + enabled: bool + size: Optional[int] = None + + role_prefix: Optional[RolePrefix] = None + window: WindowConfig diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 02e91d91128629..9bf2ae090f7686 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -5,19 +5,22 @@ from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -44,6 +47,7 @@ def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_conf return rest_tokens def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int, human_prefix: Optional[str] = None, ai_prefix: Optional[str] = None) -> str: @@ -58,13 +62,22 @@ def _get_history_messages_from_memory(self, memory: TokenBufferMemory, if ai_prefix: kwargs['ai_prefix'] = ai_prefix + if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: + kwargs['message_limit'] = memory_config.window.size + return memory.get_history_prompt_text( **kwargs ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( - max_token_limit=max_token_limit + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if (memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0) + else 10 ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index ca0efb200c15b1..613716c2cf3b6c 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -13,6 +13,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode @@ -182,6 +183,11 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, if memory: prompt_messages = self._append_chat_histories( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), prompt_messages=prompt_messages, model_config=model_config ) @@ 
-220,6 +226,11 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), max_token_limit=rest_tokens, ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py new file mode 100644 index 00000000000000..5fceeb3595c9d7 --- /dev/null +++ b/api/core/prompt/utils/prompt_message_util.py @@ -0,0 +1,85 @@ +from typing import cast + +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) +from core.prompt.simple_prompt_transform import ModelMode + + +class PromptMessageUtil: + @staticmethod + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + """ + Prompt messages to prompt for saving. + :param model_mode: model mode + :param prompt_messages: prompt messages + :return: + """ + prompts = [] + if model_mode == ModelMode.CHAT.value: + for prompt_message in prompt_messages: + if prompt_message.role == PromptMessageRole.USER: + role = 'user' + elif prompt_message.role == PromptMessageRole.ASSISTANT: + role = 'assistant' + elif prompt_message.role == PromptMessageRole.SYSTEM: + role = 'system' + else: + continue + + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + prompts.append({ + "role": role, + "text": text, + "files": files + }) + else: + prompt_message = prompt_messages[0] + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' 
+ content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + params = { + "role": 'user', + "text": text, + } + + if files: + params['files'] = files + + prompts.append(params) + + return prompts diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 263172da31b88e..befabfb3b4e333 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,7 +12,7 @@ class NodeType(Enum): """ START = 'start' END = 'end' - DIRECT_ANSWER = 'direct-answer' + ANSWER = 'answer' LLM = 'llm' KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' IF_ELSE = 'if-else' diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/answer/answer_node.py similarity index 91% rename from api/core/workflow/nodes/direct_answer/direct_answer_node.py rename to api/core/workflow/nodes/answer/answer_node.py index 22ef2ed53b7e61..381ada1a1e52d0 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -5,14 +5,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.nodes.answer.entities import AnswerNodeData from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData from models.workflow import WorkflowNodeExecutionStatus -class DirectAnswerNode(BaseNode): - _node_data_cls = DirectAnswerNodeData - node_type = NodeType.DIRECT_ANSWER +class AnswerNode(BaseNode): + _node_data_cls = AnswerNodeData + node_type = NodeType.ANSWER def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ diff --git a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/answer/entities.py similarity index 75% rename from api/core/workflow/nodes/direct_answer/entities.py rename to api/core/workflow/nodes/answer/entities.py index e7c11e3c4d1d2e..7c6fed3e4ea6f4 100644 --- a/api/core/workflow/nodes/direct_answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,9 +2,9 @@ from core.workflow.entities.variable_entities import VariableSelector -class DirectAnswerNodeData(BaseNodeData): +class AnswerNodeData(BaseNodeData): """ - DirectAnswer Node Data. + Answer Node Data. """ variables: list[VariableSelector] = [] answer: str diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bd499543d903aa..67163c93cd2b19 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,8 +1,51 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class ContextConfig(BaseModel): + """ + Context Config. 
+ """ + enabled: bool + variable_selector: Optional[list[str]] = None + + +class VisionConfig(BaseModel): + """ + Vision Config. + """ + class Configs(BaseModel): + """ + Configs. + """ + detail: Literal['low', 'high'] + + enabled: bool + configs: Optional[Configs] = None class LLMNodeData(BaseNodeData): """ LLM Node Data. """ - pass + model: ModelConfig + variables: list[VariableSelector] = [] + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + memory: Optional[MemoryConfig] = None + context: ContextConfig + vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 41e28937ac7d2a..d1050a5f5b366d 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,10 +1,27 @@ +from collections.abc import Generator from typing import Optional, cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData +from extensions.ext_database import db +from models.model import Conversation +from models.workflow import WorkflowNodeExecutionStatus class LLMNode(BaseNode): @@ -20,7 +37,341 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - pass + node_inputs = None + process_data = None + + try: + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data, variable_pool) + + node_inputs = { + **inputs + } + + # fetch files + files: list[FileObj] = self._fetch_files(node_data, variable_pool) + + if files: + node_inputs['#files#'] = [{ + 'type': file.type.value, + 'transfer_method': file.transfer_method.value, + 'url': file.url, + 'upload_file_id': file.upload_file_id, + } for file in files] + + # fetch context value + context = self._fetch_context(node_data, variable_pool) + + if context: + node_inputs['#context#'] = context + + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + + # fetch memory + memory = self._fetch_memory(node_data, variable_pool, model_instance) + + # fetch prompt messages + prompt_messages, stop = self._fetch_prompt_messages( + node_data=node_data, + inputs=inputs, + files=files, + context=context, + memory=memory, + 
model_config=model_config + ) + + process_data = { + 'model_mode': model_config.mode, + 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, + prompt_messages=prompt_messages + ) + } + + # handle invoke result + result_text, usage = self._invoke_llm( + node_data=node_data, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop + ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) + + outputs = { + 'text': result_text, + 'usage': jsonable_encoder(usage) + } + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } + ) + + def _invoke_llm(self, node_data: LLMNodeData, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str]) -> tuple[str, LLMUsage]: + """ + Invoke large language model + :param node_data: node data + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data.model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + return self._handle_invoke_result( + invoke_result=invoke_result + ) + + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + """ + Handle invoke result + :param invoke_result: invoke result + :return: + """ + model = None + prompt_messages = [] + full_text = '' + usage = None + for result in invoke_result: + text = result.delta.message.content + full_text += text + + self.publish_text_chunk(text=text) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not usage: + usage = LLMUsage.empty_usage() + + return full_text, usage + + def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + inputs = {} + for variable_selector in node_data.variables: + variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f'Variable {variable_selector.value_selector} not found') + + inputs[variable_selector.variable] = variable_value + + return inputs + + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]: + """ + Fetch files + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.vision.enabled: + return [] + + files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + if not files: + return [] + + return files + + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + """ + Fetch context + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.context.enabled: + return None + + context_value = variable_pool.get_variable_value(node_data.context.variable_selector) + if context_value: + if 
isinstance(context_value, str): + return context_value + elif isinstance(context_value, list): + context_str = '' + for item in context_value: + if 'content' not in item: + raise ValueError(f'Invalid context structure: {item}') + + context_str += item['content'] + '\n' + + return context_str.strip() + + return None + + def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.model.name + provider_name = node_data.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema( + model_name, + model_credentials + ) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory(self, node_data: LLMNodeData, + variable_pool: VariablePool, + model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + """ + Fetch memory + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.memory: + return None + + # get conversation id + conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION]) + if conversation_id is None: + return None + + # get conversation + conversation = db.session.query(Conversation).filter( + Conversation.tenant_id == self.tenant_id, + Conversation.app_id == self.app_id, + Conversation.id == conversation_id + ).first() + + if not conversation: + return None + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + return memory + + def _fetch_prompt_messages(self, node_data: LLMNodeData, + inputs: dict[str, str], + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) \ + -> tuple[list[PromptMessage], 
Optional[list[str]]]: + """ + Fetch prompt messages + :param node_data: node data + :param inputs: inputs + :param files: files + :param context: context + :param memory: memory + :param model_config: model config + :return: + """ + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template=node_data.prompt_template, + inputs=inputs, + query='', + files=files, + context=context, + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + stop = model_config.stop + + return prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: @@ -29,9 +380,20 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) :param node_data: node data :return: """ - # TODO extract variable selector to variable mapping for single step debugging - return {} + node_data = node_data + node_data = cast(cls._node_data_cls, node_data) + + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + if node_data.context.enabled: + variable_mapping['#context#'] = node_data.context.variable_selector + + if node_data.vision.enabled: + variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + return variable_mapping @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 17225c19ea0bb9..49b9d4ac4d7b4c 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -7,9 +7,9 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode @@ -24,13 +24,12 @@ from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, - WorkflowType, ) node_classes = { NodeType.START: StartNode, NodeType.END: EndNode, - NodeType.DIRECT_ANSWER: DirectAnswerNode, + NodeType.ANSWER: AnswerNode, NodeType.LLM: LLMNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.IF_ELSE: IfElseNode, @@ -156,7 +155,7 @@ def run_workflow(self, workflow: Workflow, callbacks=callbacks ) - if next_node.node_type == NodeType.END: + if next_node.node_type in [NodeType.END, NodeType.ANSWER]: break predecessor_node = next_node @@ -402,10 +401,16 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, # add to workflow_nodes_and_results workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool + ) + 
except Exception as e: + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed @@ -420,9 +425,6 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - # set end node output if in chat - self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) - workflow_nodes_and_result.result = node_run_result # node run success @@ -453,29 +455,6 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, db.session.close() - def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, - node: BaseNode, - node_run_result: NodeRunResult) -> None: - """ - Set end node output if in chat - :param workflow_run_state: workflow run state - :param node: current node - :param node_run_result: node run result - :return: - """ - if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: - workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] - if workflow_nodes_and_result_before_end: - if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') - elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer') def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 4357c6405c8a3d..5c08b9f168ad20 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,12 +2,12 @@ import pytest -from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ - ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity +from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages(): model_config_mock.model = 'gpt-3.5-turbo-instruct' prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." 
- prompt_template_entity = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt=prompt_template, - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) + prompt_template_config = CompletionModelPromptTemplate( + text=prompt_template + ) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix( + user="Human", + assistant="Assistant" + ), + window=MemoryConfig.WindowConfig( + enabled=False ) ) + inputs = { "name": "John" } @@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages(): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template_config, inputs=inputs, query=None, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages(): def test__get_chat_model_prompt_messages(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, memory_config, messages, inputs, context = get_chat_model_args files = [] query = "Hi2." @@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) assert prompt_messages[5].content == query def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, messages, inputs, context = get_chat_model_args files = [] prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=None, files=files, context=context, + memory_config=None, memory=None, model_config=model_config_mock ) @@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, 
messages, inputs, context = get_chat_model_args files = [ FileObj( @@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=None, files=files, context=context, + memory_config=None, memory=None, model_config=model_config_mock ) @@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 @@ -173,22 +181,31 @@ def get_chat_model_args(): model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' - prompt_template_entity = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( - messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ] + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False ) ) + prompt_messages = [ + ChatModelMessage( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM + ), + ChatModelMessage( + text="Hi.", + role=PromptMessageRole.USER + ), + ChatModelMessage( + text="Hello!", + role=PromptMessageRole.ASSISTANT + ) + ] + inputs = { "name": "John" } context = "I am superman." 
- return model_config_mock, prompt_template_entity, inputs, context + return model_config_mock, memory_config, prompt_messages, inputs, context From 3bd53556ca74b6c3545789d0d2d772799d6c2ea8 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 12 Mar 2024 22:41:59 +0800 Subject: [PATCH 3/6] feat: javascript code --- api/.env.example | 2 +- .../helper/code_executor/code_executor.py | 8 ++- .../code_executor/javascript_transformer.py | 54 ++++++++++++++++++- api/core/workflow/nodes/code/code_node.py | 17 ++++-- api/core/workflow/nodes/code/entities.py | 2 +- 5 files changed, 73 insertions(+), 10 deletions(-) diff --git a/api/.env.example b/api/.env.example index 4a3b1d65afdfc0..c0942412ab948f 100644 --- a/api/.env.example +++ b/api/.env.example @@ -135,4 +135,4 @@ BATCH_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT= -CODE_EXECUTINO_API_KEY= +CODE_EXECUTION_API_KEY= diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 21a8ca5f9f1ba7..adfdf6cc69f57f 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from httpx import post from pydantic import BaseModel from yarl import URL +from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer @@ -39,17 +40,20 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code template_transformer = PythonTemplateTransformer elif language == 'jinja2': template_transformer = Jinja2TemplateTransformer + elif language == 'javascript': + template_transformer = NodeJsTemplateTransformer else: raise CodeExecutionException('Unsupported language') runner = template_transformer.transform_caller(code, inputs) - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' headers = { 'X-Api-Key': CODE_EXECUTION_API_KEY } data = { - 'language': language if language != 'jinja2' else 'python3', + 'language': 'python3' if language == 'jinja2' else + 'nodejs' if language == 'javascript' else + 'python3' if language == 'python3' else None, 'code': runner, } diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py index f87f5c14cbbd7d..cc6ad16c66d833 100644 --- a/api/core/helper/code_executor/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript_transformer.py @@ -1 +1,53 @@ -# TODO \ No newline at end of file +import json +import re + +from core.helper.code_executor.template_transformer import TemplateTransformer + +NODEJS_RUNNER = """// declare main function here +{{code}} + +// execute main function, and return the result +// inputs is a dict, unstructured inputs +output = main({{inputs}}) + +// convert output to json and print +output = JSON.stringify(output) + +result = `<>${output}<>` + +console.log(result) +""" + + +class NodeJsTemplateTransformer(TemplateTransformer): + @classmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + + # transform inputs to json string + inputs_str = json.dumps(inputs, indent=4) + + # replace code and inputs + runner = NODEJS_RUNNER.replace('{{code}}', code) + runner = runner.replace('{{inputs}}', inputs_str) + + return runner + + @classmethod + def 
transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + # extract result + result = re.search(r'<>(.*)<>', response, re.DOTALL) + if not result: + raise ValueError('Failed to parse result') + result = result.group(1) + return json.loads(result) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 2c11e5ba00b9ac..5dfe398711528f 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -15,6 +15,16 @@ MAX_STRING_ARRAY_LENGTH = 30 MAX_NUMBER_ARRAY_LENGTH = 1000 +JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) { + return { + result: args1 + args2 + } +}""" + +PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + }""" class CodeNode(BaseNode): _node_data_cls = CodeNodeData @@ -42,9 +52,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } ], "code_language": "javascript", - "code": "async function main(arg1, arg2) {\n return new Promise((resolve, reject) => {" - "\n if (true) {\n resolve({\n \"result\": arg1 + arg2" - "\n });\n } else {\n reject(\"e\");\n }\n });\n}", + "code": JAVASCRIPT_DEFAULT_CODE, "outputs": [ { "variable": "result", @@ -68,8 +76,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } ], "code_language": "python3", - "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " - "+ arg2\n }", + "code": PYTHON_DEFAULT_CODE, "outputs": [ { "variable": "result", diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index d4d76c45f9f879..97e178f5df9112 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -17,4 +17,4 @@ class Output(BaseModel): variables: list[VariableSelector] code_language: Literal['python3', 'javascript'] code: str - outputs: dict[str, Output] + outputs: dict[str, Output] \ No newline at end of file From 856466320d5b2de27015fe7c35283bbbeb222d9f Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 12 Mar 2024 22:42:28 +0800 Subject: [PATCH 4/6] fix: linter --- api/core/helper/code_executor/code_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index adfdf6cc69f57f..9d74edee0e5248 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,8 +4,8 @@ from httpx import post from pydantic import BaseModel from yarl import URL -from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer +from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer From 4d7caa345809448eb49553a5b4da3053c141b843 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 23:08:14 +0800 Subject: [PATCH 5/6] add llm node test --- .../workflow/nodes/__init__.py | 0 .../workflow/nodes/test_llm.py | 132 ++++++++++++++++++ .../workflow/nodes/test_template_transform.py | 4 +- .../core/workflow/nodes/__init__.py | 0 4 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/__init__.py create mode 100644 
 create mode 100644 api/tests/integration_tests/workflow/nodes/test_llm.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/__init__.py

diff --git a/api/tests/integration_tests/workflow/nodes/__init__.py b/api/tests/integration_tests/workflow/nodes/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
new file mode 100644
index 00000000000000..18fba566bf7f71
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -0,0 +1,132 @@
+import os
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration
+from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import ModelProviderFactory
+from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
+from core.workflow.nodes.llm.llm_node import LLMNode
+from extensions.ext_database import db
+from models.provider import ProviderType
+from models.workflow import WorkflowNodeExecutionStatus
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm(setup_openai_mock):
+    node = LLMNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'llm',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5-turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'variables': [
+                    {
+                        'variable': 'weather',
+                        'value_selector': ['abc', 'output'],
+                    },
+                    {
+                        'variable': 'query',
+                        'value_selector': ['sys', 'query']
+                    }
+                ],
+                'prompt_template': [
+                    {
+                        'role': 'system',
+                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.'
+                    },
+                    {
+                        'role': 'user',
+                        'text': '{{query}}'
+                    }
+                ],
+                'memory': {
+                    'window': {
+                        'enabled': True,
+                        'size': 2
+                    }
+                },
+                'context': {
+                    'enabled': False
+                },
+                'vision': {
+                    'enabled': False
+                }
+            }
+        }
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather today?',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION: 'abababa'
+    }, user_inputs={})
+    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+    credentials = {
+        'openai_api_key': os.environ.get('OPENAI_API_KEY')
+    }
+
+    provider_instance = ModelProviderFactory().get_provider_instance('openai')
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id='1',
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            )
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+    model_config = ModelConfigWithCredentialsEntity(
+        model='gpt-3.5-turbo',
+        provider='openai',
+        mode='chat',
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+    # execute node
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs['text'] is not None
+    assert result.outputs['usage']['total_tokens'] > 0
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 4348995a055026..36cf0a070aa855 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -1,7 +1,7 @@
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
 from models.workflow import WorkflowNodeExecutionStatus
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -14,7 +14,7 @@ def test_execute_code(setup_code_executor_mock):
         app_id='1',
         workflow_id='1',
         user_id='1',
-        user_from=InvokeFrom.WEB_APP,
+        user_from=UserFrom.END_USER,
         config={
             'id': '1',
             'data': {
diff --git a/api/tests/unit_tests/core/workflow/nodes/__init__.py b/api/tests/unit_tests/core/workflow/nodes/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6

From 5fe0d50cee095a78958045171e6e657ba54074ca Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 13 Mar 2024 00:08:13 +0800
Subject: [PATCH 6/6] add deduct quota for llm node

---
 api/core/workflow/nodes/llm/llm_node.py | 56 ++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index d1050a5f5b366d..9285bbe74e87a1 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -3,6 +3,7 @@
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
+from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -21,6 +22,7 @@
 from core.workflow.nodes.llm.entities import LLMNodeData
 from extensions.ext_database import db
 from models.model import Conversation
+from models.provider import Provider, ProviderType
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -144,10 +146,15 @@ def _invoke_llm(self, node_data: LLMNodeData,
         )
 
         # handle invoke result
-        return self._handle_invoke_result(
+        text, usage = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        # deduct quota
+        self._deduct_llm_quota(model_instance=model_instance, usage=usage)
+
+        return text, usage
+
     def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
         """
         Handle invoke result
@@ -373,6 +380,53 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData,
 
         return prompt_messages, stop
 
+    def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None:
+        """
+        Deduct LLM quota
+        :param model_instance: model instance
+        :param usage: usage
+        :return:
+        """
+        provider_model_bundle = model_instance.provider_model_bundle
+        provider_configuration = provider_model_bundle.configuration
+
+        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+            return
+
+        system_configuration = provider_configuration.system_configuration
+
+        quota_unit = None
+        for quota_configuration in system_configuration.quota_configurations:
+            if quota_configuration.quota_type == system_configuration.current_quota_type:
+                quota_unit = quota_configuration.quota_unit
+
+                if quota_configuration.quota_limit == -1:
+                    return
+
+                break
+
+        used_quota = None
+        if quota_unit:
+            if quota_unit == QuotaUnit.TOKENS:
+                used_quota = usage.total_tokens
+            elif quota_unit == QuotaUnit.CREDITS:
+                used_quota = 1
+
+                if 'gpt-4' in model_instance.model:
+                    used_quota = 20
+            else:
+                used_quota = 1
+
+        if used_quota is not None:
+            db.session.query(Provider).filter(
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == model_instance.provider,
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == system_configuration.current_quota_type.value,
+                Provider.quota_limit > Provider.quota_used
+            ).update({'quota_used': Provider.quota_used + used_quota})
+            db.session.commit()
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """