From d8ab611480165f47acc2be17b02203d188c76acf Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 17 Mar 2024 21:08:25 +0800 Subject: [PATCH 1/2] fix: code --- .../workflow_event_trigger_callback.py | 2 ++ .../workflow/workflow_event_trigger_callback.py | 2 ++ api/core/app/entities/queue_entities.py | 1 + .../app/task_pipeline/workflow_cycle_manage.py | 14 ++++++++++++-- api/core/helper/code_executor/code_executor.py | 2 +- .../helper/code_executor/python_transformer.py | 2 +- .../workflow/callbacks/base_workflow_callback.py | 1 + api/core/workflow/nodes/code/code_node.py | 6 +++--- api/core/workflow/workflow_engine_manager.py | 1 + 9 files changed, 24 insertions(+), 7 deletions(-) diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 972fda2d49a66c..45d0e94bfb52fc 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -97,6 +97,7 @@ def on_workflow_node_execute_failed(self, node_id: str, node_data: BaseNodeData, error: str, inputs: Optional[dict] = None, + outputs: Optional[dict] = None, process_data: Optional[dict] = None) -> None: """ Workflow node execute failed @@ -107,6 +108,7 @@ def on_workflow_node_execute_failed(self, node_id: str, node_type=node_type, node_data=node_data, inputs=inputs, + outputs=outputs, process_data=process_data, error=error ), diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index e5a8e8d3747c42..e15ebd55485d5f 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -96,6 +96,7 @@ def on_workflow_node_execute_failed(self, node_id: str, node_data: BaseNodeData, error: str, inputs: Optional[dict] = None, + outputs: Optional[dict] = None, process_data: Optional[dict] 
= None) -> None: """ Workflow node execute failed @@ -106,6 +107,7 @@ def on_workflow_node_execute_failed(self, node_id: str, node_type=node_type, node_data=node_data, inputs=inputs, + outputs=outputs, process_data=process_data, error=error ), diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5c31996fd345a6..bf174e30e1ae58 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -168,6 +168,7 @@ class QueueNodeFailedEvent(AppQueueEvent): node_data: BaseNodeData inputs: Optional[dict] = None + outputs: Optional[dict] = None process_data: Optional[dict] = None error: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 1af2074c056e07..54bfe50a382ca4 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -218,7 +218,11 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, start_at: float, - error: str) -> WorkflowNodeExecution: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + ) -> WorkflowNodeExecution: """ Workflow node execution failed :param workflow_node_execution: workflow node execution @@ -230,6 +234,9 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE workflow_node_execution.error = error workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.finished_at = datetime.utcnow() + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None db.session.commit() 
db.session.refresh(workflow_node_execution) @@ -402,7 +409,10 @@ def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailed workflow_node_execution = self._workflow_node_execution_failed( workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, - error=event.error + error=event.error, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs ) db.session.close() diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 9d74edee0e5248..a96a2f12787f07 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -72,7 +72,7 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code response = response.json() except: raise CodeExecutionException('Failed to parse response') - + response = CodeExecutionResponse(**response) if response.code != 0: diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py index 27863ee4435354..257aa4a8f60770 100644 --- a/api/core/helper/code_executor/python_transformer.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -48,7 +48,7 @@ def transform_response(cls, response: str) -> dict: :return: """ # extract result - result = re.search(r'<>(.*)<>', response, re.DOTALL) + result = re.search(r'<>(.*?)<>', response, re.DOTALL) if not result: raise ValueError('Failed to parse result') result = result.group(1) diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 1f5472b430c96a..c2546050c5e8d6 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -57,6 +57,7 @@ def on_workflow_node_execute_failed(self, node_id: str, node_data: BaseNodeData, error: str, inputs: Optional[dict] = None, + outputs: Optional[dict] = 
None, process_data: Optional[dict] = None) -> None: """ Workflow node execute failed diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 01e4fc458311fd..ac9683edcc6b2b 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,19 +11,19 @@ MIN_NUMBER = -2 ** 63 MAX_PRECISION = 20 MAX_DEPTH = 5 -MAX_STRING_LENGTH = 1000 +MAX_STRING_LENGTH = 5000 MAX_STRING_ARRAY_LENGTH = 30 MAX_NUMBER_ARRAY_LENGTH = 1000 JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) { return { - result: args1 + args2 + result: arg1 + arg2 } }""" PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict: return { - "result": args1 + args2, + "result": arg1 + arg2, }""" class CodeNode(BaseNode): diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 143533810e2a94..99ebf7c72ef8db 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -429,6 +429,7 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node_data=node.node_data, error=node_run_result.error, inputs=node_run_result.inputs, + outputs=node_run_result.outputs, process_data=node_run_result.process_data, ) From 80f1fbba566c88c15413f28240d6c552dd3b1a6c Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 17 Mar 2024 21:26:58 +0800 Subject: [PATCH 2/2] add image file as markdown stream output --- api/controllers/files/tool_files.py | 2 +- api/controllers/service_api/app/message.py | 2 +- api/controllers/web/message.py | 2 +- api/core/app/app_config/entities.py | 4 +- .../features/file_upload/manager.py | 6 +- .../app/apps/advanced_chat/app_generator.py | 6 +- .../advanced_chat/generate_task_pipeline.py | 31 +++++--- api/core/app/apps/agent_chat/app_generator.py | 6 +- api/core/app/apps/base_app_runner.py | 6 +- api/core/app/apps/chat/app_generator.py | 6 +- api/core/app/apps/completion/app_generator.py 
| 12 +-- .../app/apps/message_based_app_generator.py | 2 +- api/core/app/apps/workflow/app_generator.py | 6 +- api/core/app/entities/app_invoke_entities.py | 4 +- api/core/app/entities/task_entities.py | 2 + .../app/task_pipeline/message_cycle_manage.py | 7 +- .../task_pipeline/workflow_cycle_manage.py | 62 +++++++++++++-- api/core/file/file_obj.py | 65 ++++++++++++++-- api/core/file/message_file_parser.py | 44 +++++------ api/core/file/upload_file_parser.py | 9 ++- api/core/memory/token_buffer_memory.py | 8 +- api/core/prompt/advanced_prompt_transform.py | 8 +- api/core/prompt/simple_prompt_transform.py | 10 +-- api/core/tools/tool_file_manager.py | 19 ++--- api/core/workflow/entities/variable_pool.py | 3 +- api/core/workflow/nodes/llm/llm_node.py | 15 ++-- api/core/workflow/nodes/tool/tool_node.py | 62 ++++++++------- api/fields/conversation_fields.py | 2 +- api/fields/message_fields.py | 2 +- api/models/model.py | 75 ++++++++++++++++++- api/services/workflow/workflow_converter.py | 4 +- .../prompt/test_advanced_prompt_transform.py | 8 +- 32 files changed, 341 insertions(+), 159 deletions(-) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 0a254c1699f73c..5a07ad2ea51800 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -27,7 +27,7 @@ def get(self, file_id, extension): raise Forbidden('Invalid request.') try: - result = ToolFileManager.get_file_generator_by_message_file_id( + result = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 4e96a924b090e5..703ff6e2581a32 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -54,7 +54,7 @@ class MessageListApi(Resource): 'conversation_id': fields.String, 'inputs': fields.Raw, 'query': fields.String, - 'answer': fields.String, + 'answer': 
fields.String(attribute='re_sign_file_url_answer'), 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 51a48ee9fbe523..3de17670586c6c 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -61,7 +61,7 @@ class MessageListApi(WebApiResource): 'conversation_id': fields.String, 'inputs': fields.Raw, 'query': fields.String, - 'answer': fields.String, + 'answer': fields.String(attribute='re_sign_file_url_answer'), 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6a521dfcc5b7b5..101e25d5821e8b 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -183,7 +183,7 @@ class TextToSpeechEntity(BaseModel): language: Optional[str] = None -class FileUploadEntity(BaseModel): +class FileExtraConfig(BaseModel): """ File Upload Entity. 
""" @@ -191,7 +191,7 @@ class FileUploadEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileUploadEntity] = None + file_upload: Optional[FileExtraConfig] = None opening_statement: Optional[str] = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 63830696ffd28c..4bfb3e21b38bcb 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,11 +1,11 @@ from typing import Optional -from core.app.app_config.entities import FileUploadEntity +from core.app.app_config.entities import FileExtraConfig class FileUploadConfigManager: @classmethod - def convert(cls, config: dict) -> Optional[FileUploadEntity]: + def convert(cls, config: dict) -> Optional[FileExtraConfig]: """ Convert model config to model config @@ -15,7 +15,7 @@ def convert(cls, config: dict) -> Optional[FileUploadEntity]: if file_upload_dict: if 'image' in file_upload_dict and file_upload_dict['image']: if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - return FileUploadEntity( + return FileExtraConfig( image_config={ 'number_limits': file_upload_dict['image']['number_limits'], 'detail': file_upload_dict['image']['detail'], diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 30b583ab06897a..6c7b37c7c67ce0 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -67,11 +67,11 @@ def generate(self, app_model: App, # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) - if 
file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( files, - file_upload_entity, + file_extra_config, user ) else: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 77801e8dc34af0..1d8558ee743ad7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -11,7 +12,6 @@ QueueAdvancedChatMessageEndEvent, QueueAnnotationReplyEvent, QueueErrorEvent, - QueueMessageFileEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, @@ -34,6 +34,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage +from core.file.file_obj import FileVar from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeType, SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode @@ -260,10 +261,10 @@ def _process_stream_response(self) -> Generator: annotation = self._handle_annotation_reply(event) if annotation: self._task_state.answer = annotation.content - elif isinstance(event, QueueMessageFileEvent): - response = self._message_file_to_stream_response(event) - if response: - yield response + # elif isinstance(event, QueueMessageFileEvent): + # response = self._message_file_to_stream_response(event) + # if response: + # yield response elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -464,10 +465,22 @@ def _generate_stream_outputs_when_node_finished(self) -> None: text = None if 
isinstance(value, str | int | float): text = str(value) - elif isinstance(value, object): # TODO FILE - # convert file to markdown - text = f'![]({value.get("url")})' - pass + elif isinstance(value, dict | list): + # handle files + file_vars = self._fetch_files_from_variable_value(value) + for file_var in file_vars: + try: + file_var_obj = FileVar(**file_var) + except Exception as e: + logger.error(f'Error creating file var: {e}') + continue + + # convert file to markdown + text = file_var_obj.to_markdown() + + if not text: + # other types + text = json.dumps(value, ensure_ascii=False) if text: for token in text: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index f3f439b12df104..0e0ff458dcd3ca 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -81,11 +81,11 @@ def generate(self, app_model: App, # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( files, - file_upload_entity, + file_extra_config, user ) else: diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 868e9e724f4081..3ecd3f4375eaac 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -14,7 +14,7 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch 
-from core.file.file_obj import FileObj +from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage @@ -33,7 +33,7 @@ def get_pre_calculate_rest_tokens(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileObj], + files: list[FileVar], query: Optional[str] = None) -> int: """ Get pre calculate rest tokens @@ -125,7 +125,7 @@ def organize_prompt_messages(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileObj], + files: list[FileVar], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 3d3ee7e446accb..6bf309ca1b50b1 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -81,11 +81,11 @@ def generate(self, app_model: App, # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( files, - file_upload_entity, + file_extra_config, user ) else: diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 
ad979eb8404ee4..b15e4b4871554d 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -76,11 +76,11 @@ def generate(self, app_model: App, # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( files, - file_upload_entity, + file_extra_config, user ) else: @@ -233,11 +233,11 @@ def generate_more_like_this(self, app_model: App, # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( message.files, - file_upload_entity, + file_extra_config, user ) else: diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 2d480d71564205..8c475b755feed9 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -226,7 +226,7 @@ def _init_generate_records(self, transfer_method=file.transfer_method.value, belongs_to='user', url=file.url, - upload_file_id=file.upload_file_id, + upload_file_id=file.related_id, created_by_role=('account' if account_id else 'end_user'), created_by=account_id or end_user_id, ) diff --git a/api/core/app/apps/workflow/app_generator.py 
b/api/core/app/apps/workflow/app_generator.py index b3721cfae97694..01b379264c31ef 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -50,11 +50,11 @@ def generate(self, app_model: App, # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) - if file_upload_entity: + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict) + if file_extra_config: file_objs = message_file_parser.validate_and_transform_files_arg( files, - file_upload_entity, + file_extra_config, user ) else: diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 01cbd7d2b2df47..c05a8a77d0544f 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileObj +from core.file.file_obj import FileVar from core.model_runtime.entities.model_entities import AIModelEntity @@ -73,7 +73,7 @@ class AppGenerateEntity(BaseModel): app_config: AppConfig inputs: dict[str, str] - files: list[FileObj] = [] + files: list[FileVar] = [] user_id: str # extras diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 124f4759851aa7..2bd92b87e232be 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -204,6 +204,7 @@ class Data(BaseModel): total_steps: int created_at: int finished_at: int + files: Optional[list[dict]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -253,6 +254,7 @@ class Data(BaseModel): 
execution_metadata: Optional[dict] = None created_at: int finished_at: int + files: Optional[list[dict]] = [] event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 305b560f95d505..16eb3d4fc28f24 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -97,6 +97,11 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti ) if message_file: + # get tool file id + tool_file_id = message_file.url.split('/')[-1] + # trim extension + tool_file_id = tool_file_id.split('.')[0] + # get extension if '.' in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' @@ -105,7 +110,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti else: extension = '.bin' # add sign url - url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) + url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) return MessageFileStreamResponse( task_id=self._application_generate_entity.task_id, diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 54bfe50a382ca4..2fc94c72400bef 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -21,6 +21,7 @@ WorkflowStartStreamResponse, WorkflowTaskState, ) +from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable from extensions.ext_database import db @@ -93,7 +94,7 @@ def _workflow_run_success(self, workflow_run: WorkflowRun, start_at: float, total_tokens: int, total_steps: int, - outputs: Optional[dict] = None) -> WorkflowRun: + outputs: Optional[str] = None) -> 
WorkflowRun: """ Workflow run success :param workflow_run: workflow run @@ -244,7 +245,8 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE return workflow_node_execution - def _workflow_start_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + def _workflow_start_to_stream_response(self, task_id: str, + workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -262,7 +264,8 @@ def _workflow_start_to_stream_response(self, task_id: str, workflow_run: Workflo ) ) - def _workflow_finish_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response(self, task_id: str, + workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. :param task_id: task id @@ -283,7 +286,8 @@ def _workflow_finish_to_stream_response(self, task_id: str, workflow_run: Workfl total_tokens=workflow_run.total_tokens, total_steps=workflow_run.total_steps, created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) + finished_at=int(workflow_run.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) ) ) @@ -310,7 +314,7 @@ def _workflow_node_start_to_stream_response(self, task_id: str, workflow_node_ex ) def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + -> NodeFinishStreamResponse: """ Workflow node finish to stream response. 
:param task_id: task id @@ -334,7 +338,8 @@ def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_e elapsed_time=workflow_node_execution.elapsed_time, execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), - finished_at=int(workflow_node_execution.finished_at.timestamp()) + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) ) ) @@ -465,3 +470,48 @@ def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceed db.session.close() return workflow_run + + def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: + """ + Fetch files from node outputs + :param outputs_dict: node outputs dict + :return: flat list of the file var dicts found across all output values + """ + files = [] + for output_value in outputs_dict.values(): + file_vars = self._fetch_files_from_variable_value(output_value) + if file_vars: + files.extend(file_vars) + + return files + + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: + """ + Fetch files from variable value + :param value: variable value (a single dict, or a list whose items are checked one by one) + :return: the file var dicts found in the value, empty list when there are none + """ + files = [] + if isinstance(value, list): + for item in value: + file_var = self._get_file_var_from_value(item) + if file_var: + files.append(file_var) + elif isinstance(value, dict): + file_var = self._get_file_var_from_value(value) + if file_var: + files.append(file_var) + + return files + + def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: + """ + Get file var from value + :param value: variable value + :return: the value itself when it is a dict whose '__variant' equals the FileVar class name, otherwise None + """ + if isinstance(value, dict): + if '__variant' in value and value['__variant'] == FileVar.__name__: + return value + + return None diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index bd896719c21835..87c4bd4bfa116d 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -3,7 
+3,8 @@ from pydantic import BaseModel -from core.app.app_config.entities import FileUploadEntity +from core.app.app_config.entities import FileExtraConfig +from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db @@ -44,27 +45,65 @@ def value_of(value): return member raise ValueError(f"No matching enum found for value '{value}'") -class FileObj(BaseModel): - id: Optional[str] + +class FileVar(BaseModel): + id: Optional[str] = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod - url: Optional[str] - upload_file_id: Optional[str] - file_upload_entity: FileUploadEntity + url: Optional[str] = None # remote url + related_id: Optional[str] = None + extra_config: Optional[FileExtraConfig] = None + filename: Optional[str] = None + extension: Optional[str] = None + mime_type: Optional[str] = None + + def to_dict(self) -> dict: + return { + '__variant': self.__class__.__name__, + 'type': self.type.value, + 'transfer_method': self.transfer_method.value, + 'url': self.preview_url, + 'related_id': self.related_id, + 'filename': self.filename, + 'extension': self.extension, + 'mime_type': self.mime_type, + } + + def to_markdown(self) -> str: + """ + Convert file to markdown + :return: + """ + preview_url = self.preview_url + if self.type == FileType.IMAGE: + text = f'![{self.filename}]({self.preview_url})' + else: + text = f'[{self.filename or self.preview_url}]({self.preview_url})' + + return text @property def data(self) -> Optional[str]: + """ + Get image data, file signed url or base64 data + depending on config MULTIMODAL_SEND_IMAGE_FORMAT + :return: + """ return self._get_data() @property def preview_url(self) -> Optional[str]: + """ + Get signed preview url + :return: + """ return self._get_data(force_url=True) @property def prompt_message_content(self) -> 
ImagePromptMessageContent: if self.type == FileType.IMAGE: - image_config = self.file_upload_entity.image_config + image_config = self.extra_config.image_config return ImagePromptMessageContent( data=self.data, @@ -79,7 +118,7 @@ def _get_data(self, force_url: bool = False) -> Optional[str]: elif self.transfer_method == FileTransferMethod.LOCAL_FILE: upload_file = (db.session.query(UploadFile) .filter( - UploadFile.id == self.upload_file_id, + UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id ).first()) @@ -87,5 +126,15 @@ def _get_data(self, force_url: bool = False) -> Optional[str]: upload_file=upload_file, force_url=force_url ) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + # get extension + if '.' in self.url: + extension = f'.{self.url.split(".")[-1]}' + if len(extension) > 10: + extension = '.bin' + else: + extension = '.bin' + # add sign url + return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 9d122c41204308..06f21c880a4195 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -2,8 +2,8 @@ import requests -from core.app.app_config.entities import FileUploadEntity -from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType +from core.app.app_config.entities import FileExtraConfig +from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar from extensions.ext_database import db from models.account import Account from models.model import EndUser, MessageFile, UploadFile @@ -16,13 +16,13 @@ def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity, - user: Union[Account, EndUser]) -> list[FileObj]: + def 
validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig, + user: Union[Account, EndUser]) -> list[FileVar]: """ validate and transform files arg :param files: - :param file_upload_entity: + :param file_extra_config: :param user: :return: """ @@ -44,14 +44,14 @@ def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity raise ValueError('Missing file upload_file_id') # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_entity) + type_file_objs = self._to_file_objs(files, file_extra_config) # validate files new_files = [] for file_type, file_objs in type_file_objs.items(): if file_type == FileType.IMAGE: # parse and validate files - image_config = file_upload_entity.image_config + image_config = file_extra_config.image_config # check if image file feature is enabled if not image_config: @@ -79,7 +79,7 @@ def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity # get upload file from upload_file_id upload_file = (db.session.query(UploadFile) .filter( - UploadFile.id == file_obj.upload_file_id, + UploadFile.id == file_obj.related_id, UploadFile.tenant_id == self.tenant_id, UploadFile.created_by == user.id, UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), @@ -95,30 +95,30 @@ def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]: + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: """ transform message files :param files: - :param file_upload_entity: + :param file_extra_config: :return: """ # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_entity) + type_file_objs = self._to_file_objs(files, file_extra_config) # return all file objs 
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]: + file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: """ transform files to file objs :param files: - :param file_upload_entity: + :param file_extra_config: :return: """ - type_file_objs: dict[FileType, list[FileObj]] = { + type_file_objs: dict[FileType, list[FileVar]] = { # Currently only support image FileType.IMAGE: [] } @@ -132,7 +132,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], if file.belongs_to == FileBelongsTo.ASSISTANT.value: continue - file_obj = self._to_file_obj(file, file_upload_entity) + file_obj = self._to_file_obj(file, file_extra_config) if file_obj.type not in type_file_objs: continue @@ -140,7 +140,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj: + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: """ transform file to file obj @@ -149,23 +149,23 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileU """ if isinstance(file, dict): transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) - return FileObj( + return FileVar( tenant_id=self.tenant_id, type=FileType.value_of(file.get('type')), transfer_method=transfer_method, url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - file_upload_entity=file_upload_entity + related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config ) else: - return FileObj( + return FileVar( id=file.id, 
tenant_id=self.tenant_id, type=FileType.value_of(file.type), transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, - upload_file_id=file.upload_file_id or None, - file_upload_entity=file_upload_entity + related_id=file.upload_file_id or None, + extra_config=file_extra_config ) def _check_image_remote_url(self, url): diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index b259a911d8bdc2..974fde178b31d9 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -13,6 +13,7 @@ IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + class UploadFileParser: @classmethod def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: @@ -23,7 +24,7 @@ def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: return None if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url: - return cls.get_signed_temp_image_url(upload_file) + return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 try: @@ -36,7 +37,7 @@ def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: return f'data:{upload_file.mime_type};base64,{encoded_string}' @classmethod - def get_signed_temp_image_url(cls, upload_file) -> str: + def get_signed_temp_image_url(cls, upload_file_id) -> str: """ get signed url from upload file @@ -44,11 +45,11 @@ def get_signed_temp_image_url(cls, upload_file) -> str: :return: """ base_url = current_app.config.get('FILES_URL') - image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview' + image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}" + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" secret_key = 
current_app.config['SECRET_KEY'].encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 471400f09baffc..182d9504ed6752 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -45,14 +45,14 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, files = message.message_files if files: if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: - file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) else: - file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict) + file_extra_config = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict) - if file_upload_entity: + if file_extra_config: file_objs = message_file_parser.transform_message_files( files, - file_upload_entity + file_extra_config ) else: file_objs = [] diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 60c77e943b3322..e50ce8ab06eea9 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,7 +1,7 @@ from typing import Optional, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileObj +from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -25,7 +25,7 @@ class AdvancedPromptTransform(PromptTransform): def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], inputs: dict, query: str, - files: 
list[FileObj], + files: list[FileVar], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], @@ -62,7 +62,7 @@ def _get_completion_model_prompt_messages(self, prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], - files: list[FileObj], + files: list[FileVar], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], @@ -113,7 +113,7 @@ def _get_chat_model_prompt_messages(self, prompt_template: list[ChatModelMessage], inputs: dict, query: Optional[str], - files: list[FileObj], + files: list[FileVar], context: Optional[str], memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 613716c2cf3b6c..79967d9004a8dc 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -5,7 +5,7 @@ from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileObj +from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -50,7 +50,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list[FileObj], + files: list[FileVar], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ @@ -161,7 +161,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileObj], + files: list[FileVar], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -204,7 +204,7 @@ def 
_get_completion_model_prompt_messages(self, app_mode: AppMode, inputs: dict, query: str, context: Optional[str], - files: list[FileObj], + files: list[FileVar], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -253,7 +253,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] for file in files: diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 1624e433566777..ceda31952e2e3b 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -21,16 +21,16 @@ class ToolFileManager: @staticmethod - def sign_file(file_id: str, extension: str) -> str: + def sign_file(tool_file_id: str, extension: str) -> str: """ sign file to get a temporary url """ base_url = current_app.config.get('FILES_URL') - file_preview_url = f'{base_url}/files/tools/{file_id}{extension}' + file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" secret_key = current_app.config['SECRET_KEY'].encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -163,23 +163,14 @@ def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None return blob, tool_file.mimetype @staticmethod - def get_file_generator_by_message_file_id(id: str) -> Union[tuple[Generator, str], None]: + def 
get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]: """ get file binary - :param id: the id of the file + :param tool_file_id: the id of the tool file :return: the binary of the file, mime type """ - message_file: MessageFile = db.session.query(MessageFile).filter( - MessageFile.id == id, - ).first() - - # get tool file id - tool_file_id = message_file.url.split('/')[-1] - # trim extension - tool_file_id = tool_file_id.split('.')[0] - tool_file: ToolFile = db.session.query(ToolFile).filter( ToolFile.id == tool_file_id, ).first() diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index ff96bc3bac0276..4bbe9bd082149d 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Any, Optional, Union +from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable -VariableValue = Union[str, int, float, dict, list] +VariableValue = Union[str, int, float, dict, list, FileVar] class ValueType(Enum): diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index cb5a33309141dd..0d860f5dd63463 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -5,7 +5,7 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileObj +from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage @@ -51,15 +51,10 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: } # fetch files - files: list[FileObj] = 
self._fetch_files(node_data, variable_pool) + files: list[FileVar] = self._fetch_files(node_data, variable_pool) if files: - node_inputs['#files#'] = [{ - 'type': file.type.value, - 'transfer_method': file.transfer_method.value, - 'url': file.url, - 'upload_file_id': file.upload_file_id, - } for file in files] + node_inputs['#files#'] = [file.to_dict() for file in files] # fetch context value context = self._fetch_context(node_data, variable_pool) @@ -202,7 +197,7 @@ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]: + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: """ Fetch files :param node_data: node data @@ -350,7 +345,7 @@ def _fetch_memory(self, node_data: LLMNodeData, def _fetch_prompt_messages(self, node_data: LLMNodeData, inputs: dict[str, str], - files: list[FileObj], + files: list[FileVar], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index d0bfd9e7973467..816a173b3456a8 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,7 +1,7 @@ from os import path from typing import cast -from core.file.file_obj import FileTransferMethod +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer @@ -58,19 +58,19 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: }, inputs=parameters ) - + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: """ Generate parameters """ return { - k.variable: - k.value if 
k.variable_type == 'static' else + k.variable: + k.value if k.variable_type == 'static' else variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else '' for k in node_data.tool_parameters } - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -87,7 +87,7 @@ def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str return plain_text, files - def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: + def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: """ Extract tool response binary """ @@ -95,46 +95,50 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) for response in tool_response: if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + response.type == ToolInvokeMessage.MessageType.IMAGE: url = response.message ext = path.splitext(url)[1] mimetype = response.meta.get('mime_type', 'image/jpeg') filename = response.save_as or url.split('/')[-1] - result.append({ - 'type': 'image', - 'transfer_method': FileTransferMethod.TOOL_FILE, - 'url': url, - 'upload_file_id': None, - 'filename': filename, - 'file-ext': ext, - 'mime-type': mimetype, - }) + + # get tool file id + tool_file_id = url.split('/')[-1] + result.append(FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=filename, + extension=ext, + mime_type=mimetype, + )) elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append({ - 'type': 'image', # TODO: only support image for now - 'transfer_method': FileTransferMethod.TOOL_FILE, - 'url': response.message, - 
'upload_file_id': None, - 'filename': response.save_as, - 'file-ext': path.splitext(response.save_as)[1], - 'mime-type': response.meta.get('mime_type', 'application/octet-stream'), - }) + # get tool file id + tool_file_id = response.message.split('/')[-1] + result.append(FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=response.save_as, + extension=path.splitext(response.save_as)[1], + mime_type=response.meta.get('mime_type', 'application/octet-stream'), + )) elif response.type == ToolInvokeMessage.MessageType.LINK: - pass # TODO: + pass # TODO: return result - + def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str: """ Extract tool response text """ return ''.join([ - f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else + f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else '' for message in tool_response ]) - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: """ diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 747b0b86abf3ef..4a8df14c9fec32 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -59,7 +59,7 @@ def format(self, value): 'query': fields.String, 'message': fields.Raw, 'message_tokens': fields.Integer, - 'answer': fields.String, + 'answer': fields.String(attribute='re_sign_file_url_answer'), 'answer_tokens': fields.Integer, 'provider_response_latency': fields.Float, 'from_source': fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 21b2e8e9e27e55..4153db373a5506 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -68,7 +68,7 @@ 'conversation_id': fields.String, 'inputs': 
fields.Raw, 'query': fields.String, - 'answer': fields.String, + 'answer': fields.String(attribute='re_sign_file_url_answer'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField, diff --git a/api/models/model.py b/api/models/model.py index 5a7311a0c72ecc..84599e930b1d2a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,4 +1,5 @@ import json +import re import uuid from enum import Enum from typing import Optional @@ -610,6 +611,71 @@ class Message(db.Model): agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) workflow_run_id = db.Column(UUID) + @property + def re_sign_file_url_answer(self) -> str: + if not self.answer: + return self.answer + + pattern = r'\[!?.*?\]\((((http|https):\/\/[\w.-]+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)' + matches = re.findall(pattern, self.answer) + + if not matches: + return self.answer + + urls = [match[0] for match in matches] + + # remove duplicate urls + urls = list(set(urls)) + + if not urls: + return self.answer + + re_sign_file_url_answer = self.answer + for url in urls: + if 'files/tools' in url: + # get tool file id + tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp=' + result = re.search(tool_file_id_pattern, url) + if not result: + continue + + tool_file_id = result.group(1) + + # get extension + if '.' 
in tool_file_id: + split_result = tool_file_id.split('.') + extension = f'.{split_result[-1]}' + if len(extension) > 10: + extension = '.bin' + tool_file_id = split_result[0] + else: + extension = '.bin' + + if not tool_file_id: + continue + + sign_url = ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=tool_file_id, + extension=extension + ) + else: + # get upload file id + upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp=' + result = re.search(upload_file_id_pattern, url) + if not result: + continue + + upload_file_id = result.group(1) + + if not upload_file_id: + continue + + sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id) + + re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) + + return re_sign_file_url_answer + @property def user_feedback(self): feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, @@ -680,7 +746,7 @@ def files(self): if message_file.transfer_method == 'local_file': upload_file = (db.session.query(UploadFile) .filter( - UploadFile.id == message_file.upload_file_id + UploadFile.id == message_file.related_id ).first()) url = UploadFileParser.get_image_data( @@ -688,6 +754,11 @@ def files(self): force_url=True ) if message_file.transfer_method == 'tool_file': + # get tool file id + tool_file_id = message_file.url.split('/')[-1] + # trim extension + tool_file_id = tool_file_id.split('.')[0] + # get extension if '.' 
in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' @@ -696,7 +767,7 @@ def files(self): else: extension = '.bin' # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file(file_id=message_file.id, extension=extension) + url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension) files.append({ 'id': message_file.id, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index b1b0b2f3159fee..af992aba85bb38 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,7 +6,7 @@ DatasetRetrieveConfigEntity, EasyUIBasedAppConfig, ExternalDataVariableEntity, - FileUploadEntity, + FileExtraConfig, ModelConfigEntity, PromptTemplateEntity, VariableEntity, @@ -416,7 +416,7 @@ def _convert_to_llm_node(self, new_app_mode: AppMode, graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileUploadEntity] = None) -> dict: + file_upload: Optional[FileExtraConfig] = None) -> dict: """ Convert to LLM Node :param new_app_mode: new app mode diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 5c08b9f168ad20..30208331aba919 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ import pytest -from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity -from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.app.app_config.entities import ModelConfigEntity, FileExtraConfig +from core.file.file_obj import FileVar, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, 
PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform @@ -138,13 +138,13 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg model_config_mock, _, messages, inputs, context = get_chat_model_args files = [ - FileObj( + FileVar( id="file1", tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="https://example.com/image1.jpg", - file_upload_entity=FileUploadEntity( + extra_config=FileExtraConfig( image_config={ "detail": "high", }