From 101db126c8a8dbe0b8cab701d003c9d0f16b4d74 Mon Sep 17 00:00:00 2001 From: pp Date: Thu, 15 Aug 2024 00:41:12 +0800 Subject: [PATCH 01/17] fix: missed rerank_mode when convert to DatasetEntity (#7269) --- api/core/app/app_config/easy_ui_based_app/dataset/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ec17db5f06a30c..f4e6675bd44435 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -93,6 +93,7 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: reranking_model=dataset_configs.get('reranking_model'), weights=dataset_configs.get('weights'), reranking_enabled=dataset_configs.get('reranking_enabled', True), + rerank_mode=dataset_configs["reranking_mode"], ) ) From d29b32fce2291c8af0d0049003ea0bdbe306bab6 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 14 Aug 2024 20:39:35 -0400 Subject: [PATCH 02/17] fix: typo in upstage/llm/_position.yaml (#7286) --- .../model_runtime/model_providers/upstage/llm/_position.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml index d4f03e1988f8b8..7992843dcb1d1d 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/upstage/llm/_position.yaml @@ -1 +1 @@ -- soloar-1-mini-chat +- solar-1-mini-chat From 7f67cb93ec7fb27af52ae147aec849cb6c5e352d Mon Sep 17 00:00:00 2001 From: yu5 <61819079+yukyu30@users.noreply.github.com> Date: Thu, 15 Aug 2024 10:44:02 +0900 Subject: [PATCH 03/17] fix ja-JP translation of secret values (#7279) --- web/i18n/ja-JP/workflow.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index c00578ca6aed4c..2a096ee29571bf 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -94,9 +94,9 @@ const translation = { }, export: { title: 'シークレット環境変数をエクスポートしますか?', - checkbox: 'シクレート値をエクスポート', + checkbox: 'シークレット値をエクスポート', ignore: 'DSLをエクスポート', - export: 'シクレート値を含むDSLをエクスポート', + export: 'シークレット値を含むDSLをエクスポート', }, }, changeHistory: { From d2ccd8ba538d694a0dbec3148d6645b4cbe29b43 Mon Sep 17 00:00:00 2001 From: Nam Vu Date: Thu, 15 Aug 2024 08:47:26 +0700 Subject: [PATCH 04/17] fix: #7222 docstrings (#7276) --- api/core/workflow/workflow_engine_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index f299f84efb3ee5..92737ab0c60ce1 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -103,10 +103,10 @@ def run_workflow( :param workflow: Workflow instance :param user_id: user id :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files + :param invoke_from: invoke from :param callbacks: workflow callbacks :param call_depth: call depth + :param variable_pool: variable pool """ # fetch workflow graph graph = workflow.graph_dict From 681ec6f845436bec8caef6eab88a04fa47bb0cd5 Mon Sep 17 00:00:00 2001 From: Hanqing Zhao Date: Thu, 15 Aug 2024 09:47:51 +0800 Subject: [PATCH 05/17] Add jp translation for variable aggregator (#7277) --- web/i18n/ja-JP/workflow.ts | 44 +++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index 2a096ee29571bf..8f506bcb46bd3e 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -99,6 +99,33 @@ const translation = { export: 'シークレット値を含むDSLをエクスポート', }, }, + chatVariable: { + panelTitle: '会話変数', + panelDescription: '会話変数は、LLMが記憶すべき対話情報を保存するために使用されます。この情報には、対話の履歴、アップロードされたファイル、ユーザーの好みなどが含まれます。読み書きが可能です。', + docLink: '詳しくはドキュメントをご覧ください。', + button: '変数を追加', + modal: { + title: '会話変数を追加', + editTitle: '会話変数を編集', + name: '名前', + namePlaceholder: '変数名前', + type: 'タイプ', + value: 'デフォルト値', + valuePlaceholder: 'デフォルト値、設定しない場合は空白にしでください', + description: '説明', + descriptionPlaceholder: '変数の説明', + editInJSON: 'JSONで編集する', + oneByOne: '次々に追加する', + editInForm: 'フォームで編集', + arrayValue: '値', + addArrayValue: '値を追加', + objectKey: 'キー', + objectType: 'タイプ', + objectValue: 'デフォルト値', + }, + storedContent: '保存されたコンテンツ', + updatedAt: '更新日は', + }, changeHistory: { title: '変更履歴', placeholder: 'まだ何も変更していません', @@ -149,6 +176,7 @@ const translation = { tabs: { 'searchBlock': 'ブロックを検索', 'blocks': 'ブロック', + 'searchTool': '検索ツール', 'tools': 'ツール', 'allTool': 'すべて', 'workflowTool': 'ワークフロー', @@ -171,8 +199,9 @@ const translation = { 'code': 'コード', 'template-transform': 'テンプレート', 'http-request': 'HTTPリクエスト', - 'variable-assigner': '変数代入', + 'variable-assigner': '変数代入器', 'variable-aggregator': '変数集約器', + 'assigner': '変数代入', 'iteration-start': 'イテレーション開始', 'iteration': 'イテレーション', 'parameter-extractor': 'パラメーター抽出', @@ -189,6 +218,7 @@ const translation = { 'template-transform': 'Jinjaテンプレート構文を使用してデータを文字列に変換します', 'http-request': 'HTTPプロトコル経由でサーバーリクエストを送信できます', 'variable-assigner': '複数のブランチの変数を1つの変数に集約し、下流のノードに対して統一された設定を行います。', + 'assigner': '変数代入ノードは、書き込み可能な変数(例えば、会話変数)に値を割り当てるために使用されます。', 'variable-aggregator': '複数のブランチの変数を1つの変数に集約し、下流のノードに対して統一された設定を行います。', 'iteration': 'リストオブジェクトに対して複数のステップを実行し、すべての結果が出力されるまで繰り返します。', 'parameter-extractor': '自然言語からツールの呼び出しやHTTPリクエストのための構造化されたパラメーターを抽出するためにLLMを使用します。', @@ -215,6 +245,7 @@ const translation = { checklistResolved: 'すべての問題が解決されました', organizeBlocks: 'ブロックを整理', change: '変更', + optional: '(オプション)', }, nodes: { common: { @@ -406,6 +437,17 @@ const translation = { }, setAssignVariable: '代入された変数を設定', }, + assigner: { + 'assignedVariable': '代入された変数', + 'writeMode': '書き込みモード', + 'writeModeTip': '代入された変数が配列の場合, 末尾に追記モードを追加する。', + 'over-write': '上書き', + 'append': '追記', + 'plus': 'プラス', + 'clear': 'クリア', + 'setVariable': '変数を設定する', + 'variable': '変数', + }, tool: { toAuthorize: '承認するには', inputVars: '入力変数', From 8f5d8397f9b5bdf431783d58f08bc23e1832a7fd Mon Sep 17 00:00:00 2001 From: wellCh4n Date: Thu, 15 Aug 2024 10:31:34 +0800 Subject: [PATCH 06/17] fix: can not input param value in tool test modal (#7281) --- .../edit-custom-collection-modal/index.tsx | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/web/app/components/tools/edit-custom-collection-modal/index.tsx b/web/app/components/tools/edit-custom-collection-modal/index.tsx index 5fcf6fb0248705..e84e15da17fbb0 100644 --- a/web/app/components/tools/edit-custom-collection-modal/index.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/index.tsx @@ -327,36 +327,36 @@ const EditCustomCollectionModal: FC = ({ + {showEmojiPicker && { + setEmoji({ content: icon, background: icon_background }) + setShowEmojiPicker(false) + }} + onClose={() => { + setShowEmojiPicker(false) + }} + />} + {credentialsModalShow && ( + setCredentialsModalShow(false)} + />) + } + {isShowTestApi && ( + setIsShowTestApi(false)} + /> + )} } isShowMask={true} clickOutsideNotOpen={true} /> - {showEmojiPicker && { - setEmoji({ content: icon, background: icon_background }) - setShowEmojiPicker(false) - }} - onClose={() => { - setShowEmojiPicker(false) - }} - />} - {credentialsModalShow && ( - setCredentialsModalShow(false)} - />) - } - {isShowTestApi && ( - setIsShowTestApi(false)} - /> - )} ) From 32dc9635569c1962501455f1febe58fa51b39d8e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 15 Aug 2024 10:53:05 +0800 Subject: [PATCH 07/17] feat(api/workflow): Add `Conversation.dialogue_count` (#7275) --- api/contexts/__init__.py | 6 +- .../app/apps/advanced_chat/app_generator.py | 108 ++++++++++++++---- api/core/app/apps/advanced_chat/app_runner.py | 53 +-------- .../advanced_chat/generate_task_pipeline.py | 14 ++- .../app/apps/message_based_app_generator.py | 5 +- api/core/app/apps/workflow/app_runner.py | 2 +- .../apps/workflow/generate_task_pipeline.py | 6 +- api/core/app/segments/__init__.py | 6 - api/core/app/segments/factory.py | 12 -- api/core/app/segments/segments.py | 13 --- api/core/app/segments/types.py | 2 - api/core/app/segments/variables.py | 9 -- .../workflow_cycle_state_manager.py | 4 +- api/core/workflow/entities/node_entities.py | 28 +---- api/core/workflow/entities/variable_pool.py | 2 +- api/core/workflow/enums.py | 25 ++++ api/core/workflow/nodes/llm/llm_node.py | 11 +- api/core/workflow/nodes/tool/tool_node.py | 11 +- api/core/workflow/workflow_engine_manager.py | 5 +- ...7ff0dc_add_conversations_dialogue_count.py | 33 ++++++ api/models/__init__.py | 6 +- api/models/model.py | 5 +- .../workflow/nodes/test_llm.py | 4 +- .../nodes/test_parameter_extractor.py | 6 +- .../core/app/segments/test_factory.py | 80 ------------- .../core/app/segments/test_segment.py | 2 +- .../core/workflow/nodes/test_answer.py | 2 +- .../core/workflow/nodes/test_if_else.py | 2 +- .../workflow/nodes/test_variable_assigner.py | 2 +- 29 files changed, 205 insertions(+), 259 deletions(-) create mode 100644 api/core/workflow/enums.py create mode 100644 api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 306fac3a931298..b6b18f5c5be57d 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,3 +1,7 @@ from contextvars import ContextVar -tenant_id: ContextVar[str] = ContextVar('tenant_id') \ No newline at end of file +from core.workflow.entities.variable_pool import VariablePool + +tenant_id: ContextVar[str] = ContextVar('tenant_id') + +workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 0cde6599926364..351eb05d8ad41c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -8,6 +8,8 @@ from flask import Flask, current_app from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -18,15 +20,20 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models.workflow import ConversationVariable, Workflow logger = logging.getLogger(__name__) @@ -120,7 +127,7 @@ def generate( conversation=conversation, stream=stream ) - + def single_iteration_generate(self, app_model: App, workflow: Workflow, node_id: str, @@ -140,10 +147,10 @@ def single_iteration_generate(self, app_model: App, """ if not node_id: raise ValueError('node_id is required') - + if args.get('inputs') is None: raise ValueError('inputs is required') - + extras = { "auto_generate_conversation_name": False } @@ -209,7 +216,7 @@ def _generate(self, *, # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - db.session.refresh(conversation) + # db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -221,15 +228,69 @@ def _generate(self, *, message_id=message.id ) + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + conversation.dialogue_count += 1 + + conversation_id = conversation.id + conversation_dialogue_count = conversation.dialogue_count + db.session.commit() + db.session.refresh(conversation) + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + + # Create a variable pool. + system_inputs = { + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION_ID: conversation_id, + SystemVariable.USER_ID: user_id, + SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + } + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + contexts.workflow_variable_pool.set(variable_pool) + # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'conversation_id': conversation.id, 'message_id': message.id, - 'user': user, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), }) worker_thread.start() @@ -242,7 +303,7 @@ def _generate(self, *, conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) return AdvancedChatAppGenerateResponseConverter.convert( @@ -253,9 +314,7 @@ def _generate(self, *, def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation_id: str, message_id: str, - user: Account, context: contextvars.Context) -> None: """ Generate worker in a new thread. @@ -282,8 +341,7 @@ def _generate_worker(self, flask_app: Flask, user_id=application_generate_entity.user_id ) else: - # get conversation and message - conversation = self._get_conversation(conversation_id) + # get message message = self._get_message(message_id) # chatbot app @@ -291,7 +349,6 @@ def _generate_worker(self, flask_app: Flask, runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - conversation=conversation, message=message ) except GenerateTaskStoppedException: @@ -314,14 +371,17 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False) \ - -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def _handle_advanced_chat_response( + self, + *, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -341,7 +401,7 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 47c53531f6aa19..5dc03979cf3b4b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -4,9 +4,6 @@ from collections.abc import Mapping from typing import Any, Optional, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -19,13 +16,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models import App, Message, Workflow logger = logging.getLogger(__name__) @@ -39,7 +33,6 @@ def run( self, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, - conversation: Conversation, message: Message, ) -> None: """ @@ -63,15 +56,6 @@ def run( inputs = application_generate_entity.inputs query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id # moderation if self.handle_input_moderation( @@ -103,38 +87,6 @@ def run( if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - session.commit() - # Convert database entities to variables - conversation_variables = [item.to_variable() for item in conversation_variables] - - # Create a variable pool. - system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( @@ -146,7 +98,6 @@ def run( invoke_from=application_generate_entity.invoke_from, callbacks=workflow_callbacks, call_depth=application_generate_entity.call_depth, - variable_pool=variable_pool, ) def single_iteration_run( @@ -155,7 +106,7 @@ def single_iteration_run( """ Single iteration run """ - app_record: App = db.session.query(App).filter(App.id == app_id).first() + app_record = db.session.query(App).filter(App.id == app_id).first() if not app_record: raise ValueError('App not found') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 91a43ed4493027..f8efcb59606d08 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ from collections.abc import Generator from typing import Any, Optional, Union, cast +import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -47,7 +48,8 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] + # Deprecated _workflow_system_variables: dict[SystemVariable, Any] _iteration_nested_relations: dict[str, list[str]] @@ -81,7 +84,7 @@ def __init__( conversation: Conversation, message: Message, user: Union[Account, EndUser], - stream: bool + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -103,11 +106,12 @@ def __init__( self._workflow = workflow self._conversation = conversation self._message = message + # Deprecated self._workflow_system_variables = { SystemVariable.QUERY: message.query, SystemVariable.FILES: application_generate_entity.files, SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id + SystemVariable.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( @@ -613,7 +617,9 @@ def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: if route_chunk_node_id == 'sys': # system variable - value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1])) + value = contexts.workflow_variable_pool.get().get(value_selector) + if value: + value = value.text elif route_chunk_node_id in self._iteration_nested_relations: # it's a iteration variable if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c5cd6864020b33..12f69f1528e241 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -258,7 +258,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat return introduction - def _get_conversation(self, conversation_id: str) -> Conversation: + def _get_conversation(self, conversation_id: str): """ Get conversation by conversation id :param conversation_id: conversation id @@ -270,6 +270,9 @@ def _get_conversation(self, conversation_id: str) -> Conversation: .first() ) + if not conversation: + raise ConversationNotExistsError() + return conversation def _get_message(self, message_id: str) -> Message: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 17a99cf1c5fd63..994919391e7ed5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,8 +11,8 @@ WorkflowAppGenerateEntity, ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2b4362150fc7e7..5022eb0438d13b 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -42,7 +42,8 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType +from core.workflow.enums import SystemVariable from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -519,7 +520,7 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: """ nodes = graph.get('nodes') - iteration_ids = [node.get('id') for node in nodes + iteration_ids = [node.get('id') for node in nodes if node.get('data', {}).get('type') in [ NodeType.ITERATION.value, NodeType.LOOP.value, @@ -530,4 +531,3 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id ] for iteration_id in iteration_ids } - \ No newline at end of file diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 174e241261fe86..7de06dfb9639fd 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -2,7 +2,6 @@ from .segments import ( ArrayAnySegment, ArraySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -13,11 +12,9 @@ from .types import SegmentType from .variables import ( ArrayAnyVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, NoneVariable, @@ -32,7 +29,6 @@ 'FloatVariable', 'ObjectVariable', 'SecretVariable', - 'FileVariable', 'StringVariable', 'ArrayAnyVariable', 'Variable', @@ -45,11 +41,9 @@ 'FloatSegment', 'ObjectSegment', 'ArrayAnySegment', - 'FileSegment', 'StringSegment', 'ArrayStringVariable', 'ArrayNumberVariable', 'ArrayObjectVariable', - 'ArrayFileVariable', 'ArraySegment', ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 91ff1fdb3de6eb..e6e9ce97747ce1 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -2,12 +2,10 @@ from typing import Any from configs import dify_config -from core.file.file_obj import FileVar from .exc import VariableError from .segments import ( ArrayAnySegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -17,11 +15,9 @@ ) from .types import SegmentType from .variables import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f'invalid number value {value}') - case SegmentType.FILE: - result = FileVariable.model_validate(mapping) case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_FILE if isinstance(value, list): - mapping = dict(mapping) - mapping['value'] = [{'value': v} for v in value] - result = ArrayFileVariable.model_validate(mapping) case _: raise VariableError(f'not supported value type {value_type}') if result.size > dify_config.MAX_VARIABLE_SIZE: @@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - if isinstance(value, FileVar): - return FileSegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 7653e1085f881a..321bc0ad020419 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -5,8 +5,6 @@ from pydantic import BaseModel, ConfigDict, field_validator -from core.file.file_obj import FileVar - from .types import SegmentType @@ -78,14 +76,7 @@ class IntegerSegment(Segment): value: int -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar - @property - def markdown(self) -> str: - return self.value.to_markdown() class ObjectSegment(Segment): @@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[FileSegment] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index a371058ef52bac..cdd2b0b4b09191 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -10,8 +10,6 @@ class SegmentType(str, Enum): ARRAY_STRING = 'array[string]' ARRAY_NUMBER = 'array[number]' ARRAY_OBJECT = 'array[object]' - ARRAY_FILE = 'array[file]' OBJECT = 'object' - FILE = 'file' GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index ac26e165425c3a..8fef707fcf298b 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -4,11 +4,9 @@ from .segments import ( ArrayAnySegment, - ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, - FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable): pass -class FileVariable(FileSegment, Variable): - pass - - class ObjectVariable(ObjectSegment, Variable): pass @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): - pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 545f31fddfaedb..8baa8ba09e4b00 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file + _workflow_system_variables: dict[SystemVariable, Any] diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 0978b09b943694..025453567bfc1b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -4,13 +4,14 @@ from pydantic import BaseModel -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class NodeType(Enum): """ Node Types. """ + START = 'start' END = 'end' ANSWER = 'answer' @@ -44,33 +45,11 @@ def value_of(cls, value: str) -> 'NodeType': raise ValueError(f'invalid node type value {value}') -class SystemVariable(Enum): - """ - System Variables. - """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - - @classmethod - def value_of(cls, value: str) -> 'SystemVariable': - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') - - class NodeRunMetadataKey(Enum): """ Node Run Metadata Key. """ + TOTAL_TOKENS = 'total_tokens' TOTAL_PRICE = 'total_price' CURRENCY = 'currency' @@ -83,6 +62,7 @@ class NodeRunResult(BaseModel): """ Node Run Result. """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index a96a26f794db8d..9fe3356faa2ef5 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,7 +6,7 @@ from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar -from core.workflow.entities.node_entities import SystemVariable +from core.workflow.enums import SystemVariable VariableValue = Union[str, int, float, dict, list, FileVar] diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py new file mode 100644 index 00000000000000..4757cf32f88988 --- /dev/null +++ b/api/core/workflow/enums.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class SystemVariable(str, Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION_ID = 'conversation_id' + USER_ID = 'user_id' + DIALOGUE_COUNT = 'dialogue_count' + + @classmethod + def value_of(cls, value: str): + """ + Get value of given system variable. + + :param value: system variable value + :return: system variable + """ + for system_variable in cls: + if system_variable.value == value: + return system_variable + raise ValueError(f'invalid system variable value {value}') diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 4431259a57543b..97b64d4b052592 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -23,8 +23,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -201,8 +202,8 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage usage = LLMUsage.empty_usage() return full_text, usage - - def _transform_chat_messages(self, + + def _transform_chat_messages(self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ @@ -249,13 +250,13 @@ def parse_dict(d: dict) -> str: # check if it's a context structure if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: return d['content'] - + # else, parse the dict try: return json.dumps(d, ensure_ascii=False) except Exception: return str(d) - + if isinstance(value, str): value = value elif isinstance(value, list): diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 87bfa5beae880b..554e3b6074ed58 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,19 +2,20 @@ from os import path from typing import Any, cast -from core.app.segments import parser +from core.app.segments import ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models.workflow import WorkflowNodeExecutionStatus +from models import WorkflowNodeExecutionStatus class ToolNode(BaseNode): @@ -140,9 +141,9 @@ def _generate_parameters( return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - # FIXME: ensure this is a ArrayVariable contains FileVariable. variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - return [file_var.value for file_var in variable.value] if variable else [] + assert isinstance(variable, ArrayAnyVariable) + return list(variable.value) if variable else [] def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 92737ab0c60ce1..3157eedfee5238 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, cast +import contexts from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom @@ -97,7 +98,7 @@ def run_workflow( invoke_from: InvokeFrom, callbacks: Sequence[WorkflowCallback], call_depth: int = 0, - variable_pool: VariablePool, + variable_pool: VariablePool | None = None, ) -> None: """ :param workflow: Workflow instance @@ -128,6 +129,8 @@ def run_workflow( raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) # init workflow run state + if not variable_pool: + variable_pool = contexts.workflow_variable_pool.get() workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), diff --git a/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py new file mode 100644 index 00000000000000..eba78e2e77d5d8 --- /dev/null +++ b/api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py @@ -0,0 +1,33 @@ +"""add conversations.dialogue_count + +Revision ID: 8782057ff0dc +Revises: 63a83fcf12ba +Create Date: 2024-08-14 13:54:25.161324 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8782057ff0dc' +down_revision = '63a83fcf12ba' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_column('dialogue_count') + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index f8313568416292..4012611471c337 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,10 +1,10 @@ from enum import Enum -from .model import AppMode +from .model import App, AppMode, Message from .types import StringUUID -from .workflow import ConversationVariable, WorkflowNodeExecutionStatus +from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus -__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus'] +__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] class CreatedByRole(Enum): diff --git a/api/models/model.py b/api/models/model.py index 9909b10dc0952d..5426d3bc83e020 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,6 +7,7 @@ from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.file.tool_file_parser import ToolFileParser @@ -512,12 +513,12 @@ class Conversation(db.Model): from_account_id = db.Column(StringUUID) read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) + dialogue_count: Mapped[int] = mapped_column(default=0) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', - passive_deletes="all") + message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index ac704e4eaf54df..4686ce06752ed5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -10,8 +10,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) \ No newline at end of file + assert 'what\'s the weather today?' in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 312ad47026beb5..adf5ffe3cadf77 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -12,8 +12,8 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db @@ -363,7 +363,7 @@ def test_extract_json_response(): { "location": "kawaii" } - hello world. + hello world. """) assert result['location'] == 'kawaii' @@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): assert latest_role != prompt.get('role') if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') \ No newline at end of file + latest_role = prompt.get('role') diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index a8429b9c1b1ccb..afd0fa50b590f8 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -3,12 +3,9 @@ import pytest from core.app.segments import ( - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, - FileSegment, - FileVariable, FloatVariable, IntegerVariable, ObjectSegment, @@ -149,83 +146,6 @@ def test_array_object_variable(): assert isinstance(variable.value[1]['key2'], int) -def test_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'file', - 'name': 'test_file', - 'description': 'Description of the variable.', - 'value': { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, FileVariable) - - -def test_array_file_variable(): - mapping = { - 'id': str(uuid4()), - 'value_type': 'array[file]', - 'name': 'test_array_file', - 'description': 'Description of the variable.', - 'value': [ - { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - { - 'id': str(uuid4()), - 'tenant_id': 'tenant_id', - 'type': 'image', - 'transfer_method': 'local_file', - 'url': 'url', - 'related_id': 'related_id', - 'extra_config': { - 'image_config': { - 'width': 100, - 'height': 100, - }, - }, - 'filename': 'filename', - 'extension': 'extension', - 'mime_type': 'mime_type', - }, - ], - } - variable = factory.build_variable_from_mapping(mapping) - assert isinstance(variable, ArrayFileVariable) - assert isinstance(variable.value[0], FileSegment) - assert isinstance(variable.value[1], FileSegment) - - def test_variable_cannot_large_than_5_kb(): with pytest.raises(VariableError): factory.build_variable_from_mapping( diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 414404b7d0362a..7e3e69ffbfc45d 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,7 +1,7 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable def test_segment_group_to_text(): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 3a32829e373c28..4617b6a42f8ec2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 4662c5ff2b26d8..d21b7785c4f4a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 8706ba05ceaee7..0b37d06fc069bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -3,8 +3,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable -from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode From 5aa373dc04aa7414556bf4282b73d342cfbe83b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Thu, 15 Aug 2024 11:19:10 +0800 Subject: [PATCH 08/17] feat: add chatgpt-4o-latest (#7289) --- .../model_providers/openai/llm/_position.yaml | 1 + .../openai/llm/chatgpt-4o-latest.yaml | 44 +++++++++++++++++++ .../model_providers/openai/llm/llm.py | 9 ++-- .../model-provider-page/model-icon/index.tsx | 2 +- 4 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 21661b9a2b8aef..ac7313aaa1bf0b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -2,6 +2,7 @@ - gpt-4o - gpt-4o-2024-05-13 - gpt-4o-2024-08-06 +- chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 - gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml new file mode 100644 index 00000000000000..98e236650c9e73 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -0,0 +1,44 @@ +model: chatgpt-4o-latest +label: + zh_Hans: chatgpt-4o-latest + en_US: chatgpt-4o-latest +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index aae2729bdfb042..556602390b3b7a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -922,11 +922,14 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" if model.startswith('ft:'): model = model.split(':')[1] + # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. + if model == "chatgpt-4o-latest": + model = "gpt-4o" + try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -946,7 +949,7 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], raise NotImplementedError( f"get_num_tokens_from_messages() is not presently implemented " f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "See https://platform.openai.com/docs/advanced-usage/managing-tokens for " "information on how messages are converted to tokens." ) num_tokens = 0 diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index 347572c755ae1a..a22ec16c252288 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -19,7 +19,7 @@ const ModelIcon: FC = ({ }) => { const language = useLanguage() - if (provider?.provider === 'openai' && modelName?.startsWith('gpt-4')) + if (provider?.provider === 'openai' && (modelName?.startsWith('gpt-4') || modelName?.includes('4o'))) return if (provider?.icon_small) { From 6ff7fd80a1f0d6940534d2b52baf86b3b54a775f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Thu, 15 Aug 2024 11:29:19 +0800 Subject: [PATCH 09/17] feat: support OPENAI json_schema (#7258) --- api/core/model_runtime/entities/defaults.py | 16 +++++++++++++--- .../model_runtime/entities/model_entities.py | 2 ++ .../openai/llm/gpt-4o-2024-08-06.yaml | 3 +++ .../openai/llm/gpt-4o-mini.yaml | 3 +++ .../model_providers/openai/llm/llm.py | 18 ++++++++++++------ .../model-parameter-modal/parameter-item.tsx | 12 +++++++++++- 6 files changed, 44 insertions(+), 10 deletions(-) diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 87fe4f681ce5c7..d2076bf74a3cde 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { @@ -94,5 +93,16 @@ }, 'required': False, 'options': ['JSON', 'XML'], - } -} \ No newline at end of file + }, + DefaultParameterName.JSON_SCHEMA: { + 'label': { + 'en_US': 'JSON Schema', + }, + 'type': 'text', + 'help': { + 'en_US': 'Set a response json schema will ensure LLM to adhere it.', + 'zh_Hans': '设置返回的json schema,llm将按照它返回', + }, + 'required': False, + }, +} diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3d471787bbef8e..c257ce63d27926 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -95,6 +95,7 @@ class DefaultParameterName(Enum): FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" RESPONSE_FORMAT = "response_format" + JSON_SCHEMA = "json_schema" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': @@ -118,6 +119,7 @@ class ParameterType(Enum): INT = "int" STRING = "string" BOOLEAN = "boolean" + TEXT = "text" class ModelPropertyKey(Enum): diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index cf2de0f73a0b84..7e430c51a710fc 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '2.50' output: '10.00' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index b97fbf8aabcae4..23dcf85085e123 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 556602390b3b7a..06135c958463e8 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,3 +1,4 @@ +import json import logging from collections.abc import Generator from typing import Optional, Union, cast @@ -544,13 +545,18 @@ def _chat_generate(self, model: str, credentials: dict, response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not currect json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format - + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 57ea4bdd118fed..eced2a8082bb86 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -100,7 +100,7 @@ const ParameterItem: FC = ({ handleInputChange(v === 1) } - const handleStringInputChange = (e: React.ChangeEvent) => { + const handleStringInputChange = (e: React.ChangeEvent) => { handleInputChange(e.target.value) } @@ -190,6 +190,16 @@ const ParameterItem: FC = ({ ) } + if (parameterRule.type === 'text') { + return ( +