diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index a7fd0164d86f73..3a8949f960b219 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,10 +1,6 @@ -import json import logging -from collections.abc import Generator -from typing import Union import flask_login -from flask import Response, stream_with_context from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -25,10 +21,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs import helper from libs.helper import uuid_value from libs.login import login_required from models.model import AppMode -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService # define completion message api for user @@ -54,7 +51,7 @@ def post(self, app_model): account = flask_login.current_user try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=account, args=args, @@ -62,7 +59,7 @@ def post(self, app_model): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -120,7 +117,7 @@ def post(self, app_model): account = flask_login.current_user try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=account, args=args, @@ -128,7 +125,7 @@ def post(self, app_model): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -151,17 +148,6 @@ def post(self, app_model): raise InternalServerError() -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - class ChatMessageStopApi(Resource): @setup_required @login_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 56d2e718e7d4d1..9a8de8ae3dd958 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,9 +1,5 @@ -import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range @@ -179,17 +175,6 @@ def get(self, app_model): return {'count': count} -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - class MessageSuggestedQuestionApi(Resource): @setup_required @login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index d5967dd5ed31ed..4994e464ba06f9 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,9 +1,6 @@ import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -13,12 +10,15 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields from fields.workflow_run_fields import workflow_run_node_execution_fields +from libs import helper from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode +from services.app_generate_service import AppGenerateService from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -87,16 +87,16 @@ def post(self, app_model: App): parser.add_argument('conversation_id', type=uuid_value, location='json') args = parser.parse_args() - workflow_service = WorkflowService() try: - response = workflow_service.run_advanced_chat_draft_workflow( + response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER + invoke_from=InvokeFrom.DEBUGGER, + streaming=True ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -121,17 +121,16 @@ def post(self, app_model: App): parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() - workflow_service = WorkflowService() - try: - response = workflow_service.run_draft_workflow( + response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER + invoke_from=InvokeFrom.DEBUGGER, + streaming=True ) - return compact_response(response) + return helper.compact_generate_response(response) except ValueError as e: raise e except Exception as e: @@ -148,12 +147,7 @@ def post(self, app_model: App, task_id: str): """ Stop workflow task """ - workflow_service = WorkflowService() - workflow_service.stop_workflow_task( - task_id=task_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER - ) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return { "result": "success" @@ -283,16 +277,6 @@ def post(self, app_model: App): return workflow -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index b8a5be0df0768a..bff494dccb9756 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,10 +1,6 @@ -import json import logging -from collections.abc import Generator from datetime import datetime -from typing import Union -from flask import Response, stream_with_context from flask_login import current_user from flask_restful import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -26,8 +22,9 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from libs import helper from libs.helper import uuid_value -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService # define completion api for user @@ -53,7 +50,7 @@ def post(self, installed_app): db.session.commit() try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, @@ -61,7 +58,7 @@ def post(self, installed_app): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -117,7 +114,7 @@ def post(self, installed_app): db.session.commit() try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, @@ -125,7 +122,7 @@ def post(self, installed_app): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -159,17 +156,6 @@ def post(self, installed_app, task_id): return {'result': 'success'}, 200 -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index fdb0eae24f00d4..ef051233b03031 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,9 +1,5 @@ -import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_login import current_user from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range @@ -28,8 +24,9 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields +from libs import helper from libs.helper import uuid_value -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -91,14 +88,14 @@ def get(self, installed_app, message_id): streaming = args['response_mode'] == 'streaming' try: - response = CompletionService.generate_more_like_this( + response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except MessageNotExistsError: raise NotFound("Message Not Exists.") except MoreLikeThisDisabledError: @@ -118,17 +115,6 @@ def get(self, installed_app, message_id): raise InternalServerError() -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 410fb5bffd8e4e..3f284d2326d4f8 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,9 +1,5 @@ -import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -23,9 +19,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs import helper from libs.helper import uuid_value from models.model import App, EndUser -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService class CompletionApi(Resource): @@ -48,7 +45,7 @@ def post(self, app_model: App, end_user: EndUser): args['auto_generate_name'] = False try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, @@ -56,7 +53,7 @@ def post(self, app_model: App, end_user: EndUser): streaming=streaming, ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -110,7 +107,7 @@ def post(self, app_model: App, end_user: EndUser): streaming = args['response_mode'] == 'streaming' try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, @@ -118,7 +115,7 @@ def post(self, app_model: App, end_user: EndUser): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -152,17 +149,6 @@ def post(self, app_model: App, end_user: EndUser, task_id): return {'result': 'success'}, 200 -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index ed1378e7e3d85f..452ce8709e9673 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,9 +1,5 @@ -import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_restful import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -24,8 +20,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs import helper from libs.helper import uuid_value -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService # define completion api for user @@ -48,7 +45,7 @@ def post(self, app_model, end_user): args['auto_generate_name'] = False try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, @@ -56,7 +53,7 @@ def post(self, app_model, end_user): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -108,7 +105,7 @@ def post(self, app_model, end_user): args['auto_generate_name'] = False try: - response = CompletionService.completion( + response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, @@ -116,7 +113,7 @@ def post(self, app_model, end_user): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -149,17 +146,6 @@ def post(self, app_model, end_user, task_id): return {'result': 'success'}, 200 -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 1acb92dbf1eda5..c4e49118d8e91f 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,9 +1,5 @@ -import json import logging -from collections.abc import Generator -from typing import Union -from flask import Response, stream_with_context from flask_restful import fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound @@ -26,8 +22,9 @@ from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields +from libs import helper from libs.helper import TimestampField, uuid_value -from services.completion_service import CompletionService +from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -127,7 +124,7 @@ def get(self, app_model, end_user, message_id): streaming = args['response_mode'] == 'streaming' try: - response = CompletionService.generate_more_like_this( + response = AppGenerateService.generate_more_like_this( app_model=app_model, user=end_user, message_id=message_id, @@ -135,7 +132,7 @@ def get(self, app_model, end_user, message_id): streaming=streaming ) - return compact_response(response) + return helper.compact_generate_response(response) except MessageNotExistsError: raise NotFound("Message Not Exists.") except MoreLikeThisDisabledError: @@ -155,17 +152,6 @@ def get(self, app_model, end_user, message_id): raise InternalServerError() -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): if app_model.mode != 'chat': diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 1a33a3230bd773..30b583ab06897a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -10,11 +10,13 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db @@ -32,7 +34,7 @@ def generate(self, app_model: App, args: dict, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -123,7 +125,7 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_advanced_chat_response( + response = self._handle_advanced_chat_response( application_generate_entity=application_generate_entity, workflow=workflow, queue_manager=queue_manager, @@ -133,6 +135,11 @@ def generate(self, app_model: App, stream=stream ) + return AdvancedChatAppGenerateResponseConverter.convert( + response=response, + invoke_from=invoke_from + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, @@ -185,7 +192,8 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh conversation: Conversation, message: Message, user: Union[Account, EndUser], - stream: bool = False) -> Union[dict, Generator]: + stream: bool = False) \ + -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py new file mode 100644 index 00000000000000..d211db951149ce --- /dev/null +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -0,0 +1,107 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = ChatbotAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + response = { + 'event': 'message', + 'task_id': blocking_response.task_id, + 'id': blocking_response.data.id, + 'message_id': blocking_response.data.message_id, + 'conversation_id': blocking_response.data.conversation_id, + 'mode': blocking_response.data.mode, + 'answer': blocking_response.data.answer, + 'metadata': blocking_response.data.metadata, + 'created_at': blocking_response.data.created_at + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get('metadata', {}) + response['metadata'] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + sub_stream_response_dict = sub_stream_response.to_dict() + if isinstance(sub_stream_response, MessageEndStreamResponse): + metadata = sub_stream_response_dict.get('metadata', {}) + sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + + response_chunk.update(sub_stream_response_dict) + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ca4b143027c512..77801e8dc34af0 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,16 +1,11 @@ -import json import logging import time from collections.abc import Generator -from typing import Optional, Union, cast - -from pydantic import BaseModel, Extra +from typing import Any, Optional, Union, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -29,84 +24,42 @@ QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.app.entities.task_entities import ( + AdvancedChatTaskState, + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + MessageEndStreamResponse, + StreamGenerateRoute, +) +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeType, SystemVariable from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk +from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account -from models.model import Conversation, EndUser, Message, MessageFile +from models.model import Conversation, EndUser, Message from models.workflow import ( Workflow, WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) -from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) -class StreamGenerateRoute(BaseModel): - """ - StreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - -class TaskState(BaseModel): - """ - TaskState entity - """ - - class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - answer: str = "" - metadata: dict = {} - usage: LLMUsage - - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[StreamGenerateRoute] = None - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - -class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): +class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: AdvancedChatTaskState + _application_generate_entity: AdvancedChatAppGenerateEntity + _workflow: Workflow + _user: Union[Account, EndUser] + _workflow_system_variables: dict[SystemVariable, Any] def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, @@ -116,7 +69,7 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, user: Union[Account, EndUser], stream: bool) -> None: """ - Initialize GenerateTaskPipeline. + Initialize AdvancedChatAppGenerateTaskPipeline. :param application_generate_entity: application generate entity :param workflow: workflow :param queue_manager: queue manager @@ -125,25 +78,27 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, :param user: user :param stream: stream """ - self._application_generate_entity = application_generate_entity + super().__init__(application_generate_entity, queue_manager, user, stream) + self._workflow = workflow - self._queue_manager = queue_manager self._conversation = conversation self._message = message - self._user = user - self._task_state = TaskState( + self._workflow_system_variables = { + SystemVariable.QUERY: message.query, + SystemVariable.FILES: application_generate_entity.files, + SystemVariable.CONVERSATION: conversation.id, + } + + self._task_state = AdvancedChatTaskState( usage=LLMUsage.empty_usage() ) - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream if stream: self._stream_generate_routes = self._get_stream_generate_routes() else: self._stream_generate_routes = None - def process(self) -> Union[dict, Generator]: + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. :return: @@ -153,11 +108,20 @@ def process(self) -> Union[dict, Generator]: db.session.close() if self._stream: - return self._process_stream_response() + generator = self._process_stream_response() + for stream_response in generator: + yield ChatbotAppStreamResponse( + conversation_id=self._conversation.id, + message_id=self._message.id, + created_at=int(self._message.created_at.timestamp()), + stream_response=stream_response + ) + + # yield "data: " + json.dumps(response) + "\n\n" else: return self._process_blocking_response() - def _process_blocking_response(self) -> dict: + def _process_blocking_response(self) -> ChatbotAppBlockingResponse: """ Process blocking response. :return: @@ -166,65 +130,64 @@ def _process_blocking_response(self) -> dict: event = queue_message.event if isinstance(event, QueueErrorEvent): - raise self._handle_error(event) + err = self._handle_error(event) + raise err elif isinstance(event, QueueRetrieverResourcesEvent): - self._task_state.metadata['retriever_resources'] = event.retriever_resources + self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): - annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + annotation = self._handle_annotation_reply(event) if annotation: - account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } - } - self._task_state.answer = annotation.content elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() + self._handle_workflow_start() elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) + self._handle_node_start(event) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) + self._handle_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._on_workflow_finished(event) + workflow_run = self._handle_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() - - self._task_state.answer = self._output_moderation_handler.moderation_completion( - completion=self._task_state.answer, - public_event=False - ) + # handle output moderation + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer # Save message self._save_message() - response = { - 'event': 'message', - 'task_id': self._application_generate_entity.task_id, - 'id': self._message.id, - 'message_id': self._message.id, - 'conversation_id': self._conversation.id, - 'mode': self._conversation.mode, - 'answer': self._task_state.answer, - 'metadata': {}, - 'created_at': int(self._message.created_at.timestamp()) - } - - if self._task_state.metadata: - response['metadata'] = self._get_response_metadata() - - return response + return self._to_blocking_response() else: continue + raise Exception('Queue listening stopped unexpectedly.') + + def _to_blocking_response(self) -> ChatbotAppBlockingResponse: + """ + To blocking response. + :return: + """ + extras = {} + if self._task_state.metadata: + extras['metadata'] = self._task_state.metadata + + response = ChatbotAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=ChatbotAppBlockingResponse.Data( + id=self._message.id, + mode=self._conversation.mode, + conversation_id=self._conversation.id, + message_id=self._message.id, + answer=self._task_state.answer, + created_at=int(self._message.created_at.timestamp()), + **extras + ) + ) + + return response + def _process_stream_response(self) -> Generator: """ Process stream response. @@ -234,81 +197,42 @@ def _process_stream_response(self) -> Generator: event = message.event if isinstance(event, QueueErrorEvent): - data = self._error_to_stream_response_data(self._handle_error(event)) - yield self._yield_response(data) + err = self._handle_error(event) + yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._on_workflow_start() - - response = { - 'event': 'workflow_started', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_run.id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'sequence_number': workflow_run.sequence_number, - 'created_at': int(workflow_run.created_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_run = self._handle_workflow_start() + yield self._workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._on_node_start(event) - - response = { - 'event': 'node_started', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_node_execution.workflow_run_id, - 'data': { - 'id': workflow_node_execution.id, - 'node_id': workflow_node_execution.node_id, - 'index': workflow_node_execution.index, - 'predecessor_node_id': workflow_node_execution.predecessor_node_id, - 'inputs': workflow_node_execution.inputs_dict, - 'created_at': int(workflow_node_execution.created_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_node_execution = self._handle_node_start(event) + + # search stream_generate_routes if node id is answer start at node + if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: + self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] + + yield self._workflow_node_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._on_node_finished(event) - - if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - - response = { - 'event': 'node_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_node_execution.workflow_run_id, - 'data': { - 'id': workflow_node_execution.id, - 'node_id': workflow_node_execution.node_id, - 'index': workflow_node_execution.index, - 'predecessor_node_id': workflow_node_execution.predecessor_node_id, - 'inputs': workflow_node_execution.inputs_dict, - 'process_data': workflow_node_execution.process_data_dict, - 'outputs': workflow_node_execution.outputs_dict, - 'status': workflow_node_execution.status, - 'error': workflow_node_execution.error, - 'elapsed_time': workflow_node_execution.elapsed_time, - 'execution_metadata': workflow_node_execution.execution_metadata_dict, - 'created_at': int(workflow_node_execution.created_at.timestamp()), - 'finished_at': int(workflow_node_execution.finished_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_node_execution = self._handle_node_finished(event) + + # stream outputs when node finished + self._generate_stream_outputs_when_node_finished() + + yield self._workflow_node_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._on_workflow_finished(event) + workflow_run = self._handle_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - data = self._error_to_stream_response_data(self._handle_error(err_event)) - yield self._yield_response(data) + yield self._error_to_stream_response(self._handle_error(err_event)) break self._queue_manager.publish( @@ -316,292 +240,54 @@ def _process_stream_response(self) -> Generator: PublishFrom.TASK_PIPELINE ) - workflow_run_response = { - 'event': 'workflow_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_run.id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'status': workflow_run.status, - 'outputs': workflow_run.outputs_dict, - 'error': workflow_run.error, - 'elapsed_time': workflow_run.elapsed_time, - 'total_tokens': workflow_run.total_tokens, - 'total_steps': workflow_run.total_steps, - 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) - } - } - - yield self._yield_response(workflow_run_response) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) elif isinstance(event, QueueAdvancedChatMessageEndEvent): - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() - - self._task_state.answer = self._output_moderation_handler.moderation_completion( - completion=self._task_state.answer, - public_event=False - ) - - self._output_moderation_handler = None - - replace_response = { - 'event': 'message_replace', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'conversation_id': self._conversation.id, - 'answer': self._task_state.answer, - 'created_at': int(self._message.created_at.timestamp()) - } - - yield self._yield_response(replace_response) + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message self._save_message() - response = { - 'event': 'message_end', - 'task_id': self._application_generate_entity.task_id, - 'id': self._message.id, - 'message_id': self._message.id, - 'conversation_id': self._conversation.id, - } - - if self._task_state.metadata: - response['metadata'] = self._get_response_metadata() - - yield self._yield_response(response) + yield self._message_end_to_stream_response() elif isinstance(event, QueueRetrieverResourcesEvent): - self._task_state.metadata['retriever_resources'] = event.retriever_resources + self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): - annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + annotation = self._handle_annotation_reply(event) if annotation: - account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } - } - self._task_state.answer = annotation.content elif isinstance(event, QueueMessageFileEvent): - message_file: MessageFile = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) - # get extension - if '.' in message_file.url: - extension = f'.{message_file.url.split(".")[-1]}' - if len(extension) > 10: - extension = '.bin' - else: - extension = '.bin' - # add sign url - url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) - - if message_file: - response = { - 'event': 'message_file', - 'conversation_id': self._conversation.id, - 'id': message_file.id, - 'type': message_file.type, - 'belongs_to': message_file.belongs_to or 'user', - 'url': url - } - - yield self._yield_response(response) + response = self._message_file_to_stream_response(event) + if response: + yield response elif isinstance(event, QueueTextChunkEvent): + delta_text = event.text + if delta_text is None: + continue + if not self._is_stream_out_support( event=event ): continue - delta_text = event.text - if delta_text is None: + # handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(delta_text) + if should_direct_answer: continue - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output - self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE - ) - - self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE - ) - continue - else: - self._output_moderation_handler.append_new_token(delta_text) - self._task_state.answer += delta_text - response = self._handle_chunk(delta_text) - yield self._yield_response(response) + yield self._message_to_stream_response(delta_text, self._message.id) elif isinstance(event, QueueMessageReplaceEvent): - response = { - 'event': 'message_replace', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'conversation_id': self._conversation.id, - 'answer': event.text, - 'created_at': int(self._message.created_at.timestamp()) - } - - yield self._yield_response(response) + yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): - yield "event: ping\n\n" + yield self._ping_stream_response() else: continue - def _on_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs={ - SystemVariable.QUERY: self._message.query, - SystemVariable.FILES: self._application_generate_entity.files, - SystemVariable.CONVERSATION: self._conversation.id, - } - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: - self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] - - # stream outputs from start - self._generate_stream_outputs_when_node_start() - - return workflow_node_execution - - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=event.execution_metadata - ) - - if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error - ) - - # stream outputs when node finished - self._generate_stream_outputs_when_node_finished() - - db.session.close() - - return workflow_node_execution - - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ - -> WorkflowRun: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == self._task_state.workflow_run_id).first()) - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.' - ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs - ) - - self._task_state.workflow_run_id = workflow_run.id - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - - db.session.close() - - return workflow_run - def _save_message(self) -> None: """ Save message. @@ -636,140 +322,20 @@ def _save_message(self) -> None: extras=self._application_generate_entity.extras ) - def _handle_chunk(self, text: str) -> dict: - """ - Handle completed event. - :param text: text - :return: - """ - response = { - 'event': 'message', - 'id': self._message.id, - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'conversation_id': self._conversation.id, - 'answer': text, - 'created_at': int(self._message.created_at.timestamp()) - } - - return response - - def _handle_error(self, event: QueueErrorEvent) -> Exception: + def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ - Handle error event. - :param event: event + Message end to stream response. :return: """ - logger.debug("error: %s", event.error) - e = event.error + extras = {} + if self._task_state.metadata: + extras['metadata'] = self._task_state.metadata - if isinstance(e, InvokeAuthorizationError): - return InvokeAuthorizationError('Incorrect API key provided') - elif isinstance(e, InvokeError) or isinstance(e, ValueError): - return e - else: - return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) - - def _error_to_stream_response_data(self, e: Exception) -> dict: - """ - Error to stream response. - :param e: exception - :return: - """ - error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, - QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 - }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} - } - - # Determine the response based on the type of exception - data = None - for k, v in error_responses.items(): - if isinstance(e, k): - data = v - - if data: - data.setdefault('message', getattr(e, 'description', str(e))) - else: - logging.error(e) - data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 - } - - return { - 'event': 'error', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - **data - } - - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - else: - metadata['retriever_resources'] = [] - for resource in self._task_state.metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata - - def _yield_response(self, response: dict) -> str: - """ - Yield response. - :param response: response - :return: - """ - return "data: " + json.dumps(response) + "\n\n" - - def _init_output_moderation(self) -> Optional[OutputModeration]: - """ - Init output moderation. - :return: - """ - app_config = self._application_generate_entity.app_config - sensitive_word_avoidance = app_config.sensitive_word_avoidance - - if sensitive_word_avoidance: - return OutputModeration( - tenant_id=app_config.tenant_id, - app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager - ) + return MessageEndStreamResponse( + task_id=self._application_generate_entity.task_id, + id=self._message.id, + **extras + ) def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]: """ @@ -840,34 +406,6 @@ def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \ return start_node_id - def _generate_stream_outputs_when_node_start(self) -> None: - """ - Generate stream outputs. - :return: - """ - if not self._task_state.current_stream_generate_state: - return - - for route_chunk in self._task_state.current_stream_generate_state.generate_route: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - for token in route_chunk.text: - self._queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.TASK_PIPELINE - ) - time.sleep(0.01) - - self._task_state.current_stream_generate_state.current_route_position += 1 - else: - break - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route): - self._task_state.current_stream_generate_state = None - def _generate_stream_outputs_when_node_finished(self) -> None: """ Generate stream outputs. @@ -985,3 +523,29 @@ def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: return False return True + + def _handle_output_moderation_chunk(self, text: str) -> bool: + """ + Handle output moderation chunk. + :param text: text + :return: True if output moderation should direct output, otherwise False + """ + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + return True + else: + self._output_moderation_handler.append_new_token(text) + + return False diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index cc9b0785f56106..f3f439b12df104 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -11,6 +11,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -30,7 +31,7 @@ def generate(self, app_model: App, args: Any, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -141,14 +142,20 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_response( + response = self._handle_response( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) + return AgentChatAppGenerateResponseConverter.convert( + response=response, + invoke_from=invoke_from + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py new file mode 100644 index 00000000000000..bd91c5269e2c9f --- /dev/null +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -0,0 +1,107 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = ChatbotAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + response = { + 'event': 'message', + 'task_id': blocking_response.task_id, + 'id': blocking_response.data.id, + 'message_id': blocking_response.data.message_id, + 'conversation_id': blocking_response.data.conversation_id, + 'mode': blocking_response.data.mode, + 'answer': blocking_response.data.answer, + 'metadata': blocking_response.data.metadata, + 'created_at': blocking_response.data.created_at + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get('metadata', {}) + response['metadata'] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + sub_stream_response_dict = sub_stream_response.to_dict() + if isinstance(sub_stream_response, MessageEndStreamResponse): + metadata = sub_stream_response_dict.get('metadata', {}) + sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + + response_chunk.update(sub_stream_response_dict) + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py new file mode 100644 index 00000000000000..cbc07b1c702551 --- /dev/null +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Union + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse + + +class AppGenerateResponseConverter(ABC): + _blocking_response_type: type[AppBlockingResponse] + + @classmethod + def convert(cls, response: Union[ + AppBlockingResponse, + Generator[AppStreamResponse, None, None] + ], invoke_from: InvokeFrom) -> Union[ + dict, + Generator[str, None, None] + ]: + if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: + if isinstance(response, cls._blocking_response_type): + return cls.convert_blocking_full_response(response) + else: + for chunk in cls.convert_stream_full_response(response): + yield f'data: {chunk}\n\n' + else: + if isinstance(response, cls._blocking_response_type): + return cls.convert_blocking_simple_response(response) + else: + for chunk in cls.convert_stream_simple_response(response): + yield f'data: {chunk}\n\n' + + @classmethod + @abstractmethod + def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict: + raise NotImplementedError + + @classmethod + @abstractmethod + def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict: + raise NotImplementedError + + @classmethod + @abstractmethod + def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + raise NotImplementedError + + @classmethod + @abstractmethod + def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + raise NotImplementedError + + @classmethod + def _get_simple_metadata(cls, metadata: dict) -> dict: + """ + Get simple metadata. + :param metadata: metadata + :return: + """ + # show_retrieve_source + if 'retriever_resources' in metadata: + metadata['retriever_resources'] = [] + for resource in metadata['retriever_resources']: + metadata['retriever_resources'].append({ + 'segment_id': resource['segment_id'], + 'position': resource['position'], + 'document_name': resource['document_name'], + 'score': resource['score'], + 'content': resource['content'], + }) + + # show annotation reply + if 'annotation_reply' in metadata: + del metadata['annotation_reply'] + + # show usage + if 'usage' in metadata: + del metadata['usage'] + + return metadata \ No newline at end of file diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 58287ba6587d09..3d3ee7e446accb 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -12,6 +12,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom @@ -30,7 +31,7 @@ def generate(self, app_model: App, args: Any, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -141,14 +142,20 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_response( + response = self._handle_response( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) + return ChatAppGenerateResponseConverter.convert( + response=response, + invoke_from=invoke_from + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: ChatAppGenerateEntity, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py new file mode 100644 index 00000000000000..898561e01aa58a --- /dev/null +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -0,0 +1,107 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = ChatbotAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + response = { + 'event': 'message', + 'task_id': blocking_response.task_id, + 'id': blocking_response.data.id, + 'message_id': blocking_response.data.message_id, + 'conversation_id': blocking_response.data.conversation_id, + 'mode': blocking_response.data.mode, + 'answer': blocking_response.data.answer, + 'metadata': blocking_response.data.metadata, + 'created_at': blocking_response.data.created_at + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get('metadata', {}) + response['metadata'] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(ChatbotAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'conversation_id': chunk.conversation_id, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + sub_stream_response_dict = sub_stream_response.to_dict() + if isinstance(sub_stream_response, MessageEndStreamResponse): + metadata = sub_stream_response_dict.get('metadata', {}) + sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + + response_chunk.update(sub_stream_response_dict) + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index fb6246972075cf..ad979eb8404ee4 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -12,6 +12,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner +from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom @@ -32,7 +33,7 @@ def generate(self, app_model: App, args: Any, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -133,14 +134,20 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_response( + response = self._handle_response( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) + return CompletionAppGenerateResponseConverter.convert( + response=response, + invoke_from=invoke_from + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, @@ -189,7 +196,7 @@ def generate_more_like_this(self, app_model: App, user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -289,5 +296,6 @@ def generate_more_like_this(self, app_model: App, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py new file mode 100644 index 00000000000000..0570f815a6e581 --- /dev/null +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -0,0 +1,104 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + CompletionAppBlockingResponse, + CompletionAppStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) + + +class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = CompletionAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + response = { + 'event': 'message', + 'task_id': blocking_response.task_id, + 'id': blocking_response.data.id, + 'message_id': blocking_response.data.message_id, + 'mode': blocking_response.data.mode, + 'answer': blocking_response.data.answer, + 'metadata': blocking_response.data.metadata, + 'created_at': blocking_response.data.created_at + } + + return response + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + response = cls.convert_blocking_full_response(blocking_response) + + metadata = response.get('metadata', {}) + response['metadata'] = cls._get_simple_metadata(metadata) + + return response + + @classmethod + def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(CompletionAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(CompletionAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'message_id': chunk.message_id, + 'created_at': chunk.created_at + } + + sub_stream_response_dict = sub_stream_response.to_dict() + if isinstance(sub_stream_response, MessageEndStreamResponse): + metadata = sub_stream_response_dict.get('metadata', {}) + sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + + response_chunk.update(sub_stream_response_dict) + yield json.dumps(response_chunk) diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py deleted file mode 100644 index 412029b02491f8..00000000000000 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ /dev/null @@ -1,600 +0,0 @@ -import json -import logging -import time -from collections.abc import Generator -from typing import Optional, Union, cast - -from pydantic import BaseModel - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import ( - AgentChatAppGenerateEntity, - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - InvokeFrom, -) -from core.app.entities.queue_entities import ( - QueueAgentMessageEvent, - QueueAgentThoughtEvent, - QueueAnnotationReplyEvent, - QueueErrorEvent, - QueueLLMChunkEvent, - QueueMessageEndEvent, - QueueMessageFileEvent, - QueueMessageReplaceEvent, - QueuePingEvent, - QueueRetrieverResourcesEvent, - QueueStopEvent, -) -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, -) -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.tools.tool_file_manager import ToolFileManager -from events.message_event import message_was_created -from extensions.ext_database import db -from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile -from services.annotation_service import AppAnnotationService - -logger = logging.getLogger(__name__) - - -class TaskState(BaseModel): - """ - TaskState entity - """ - llm_result: LLMResult - metadata: dict = {} - - -class EasyUIBasedGenerateTaskPipeline: - """ - EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. - """ - - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - """ - self._application_generate_entity = application_generate_entity - self._model_config = application_generate_entity.model_config - self._queue_manager = queue_manager - self._conversation = conversation - self._message = message - self._task_state = TaskState( - llm_result=LLMResult( - model=self._model_config.model, - prompt_messages=[], - message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() - ) - ) - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - - def process(self, stream: bool) -> Union[dict, Generator]: - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - - if stream: - return self._process_stream_response() - else: - return self._process_blocking_response() - - def _process_blocking_response(self) -> dict: - """ - Process blocking response. - :return: - """ - for queue_message in self._queue_manager.listen(): - event = queue_message.event - - if isinstance(event, QueueErrorEvent): - raise self._handle_error(event) - elif isinstance(event, QueueRetrieverResourcesEvent): - self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, QueueAnnotationReplyEvent): - annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) - if annotation: - account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } - } - - self._task_state.llm_result.message.content = annotation.content - elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): - if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result - else: - model_config = self._model_config - model = model_config.model - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # calculate num tokens - prompt_tokens = 0 - if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, - self._task_state.llm_result.prompt_messages - ) - - completion_tokens = 0 - if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, - [self._task_state.llm_result.message] - ) - - credentials = model_config.credentials - - # transform usage - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens - ) - - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) - - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() - - self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( - completion=self._task_state.llm_result.message.content, - public_event=False - ) - - # Save message - self._save_message(self._task_state.llm_result) - - response = { - 'event': 'message', - 'task_id': self._application_generate_entity.task_id, - 'id': self._message.id, - 'message_id': self._message.id, - 'mode': self._conversation.mode, - 'answer': self._task_state.llm_result.message.content, - 'metadata': {}, - 'created_at': int(self._message.created_at.timestamp()) - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - if self._task_state.metadata: - response['metadata'] = self._get_response_metadata() - - return response - else: - continue - - def _process_stream_response(self) -> Generator: - """ - Process stream response. - :return: - """ - for message in self._queue_manager.listen(): - event = message.event - - if isinstance(event, QueueErrorEvent): - data = self._error_to_stream_response_data(self._handle_error(event)) - yield self._yield_response(data) - break - elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): - if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result - else: - model_config = self._model_config - model = model_config.model - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # calculate num tokens - prompt_tokens = 0 - if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, - self._task_state.llm_result.prompt_messages - ) - - completion_tokens = 0 - if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, - [self._task_state.llm_result.message] - ) - - credentials = model_config.credentials - - # transform usage - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens - ) - - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) - - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() - - self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( - completion=self._task_state.llm_result.message.content, - public_event=False - ) - - self._output_moderation_handler = None - - replace_response = { - 'event': 'message_replace', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'answer': self._task_state.llm_result.message.content, - 'created_at': int(self._message.created_at.timestamp()) - } - - if self._conversation.mode != AppMode.COMPLETION.value: - replace_response['conversation_id'] = self._conversation.id - - yield self._yield_response(replace_response) - - # Save message - self._save_message(self._task_state.llm_result) - - response = { - 'event': 'message_end', - 'task_id': self._application_generate_entity.task_id, - 'id': self._message.id, - 'message_id': self._message.id, - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - if self._task_state.metadata: - response['metadata'] = self._get_response_metadata() - - yield self._yield_response(response) - elif isinstance(event, QueueRetrieverResourcesEvent): - self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, QueueAnnotationReplyEvent): - annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) - if annotation: - account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } - } - - self._task_state.llm_result.message.content = annotation.content - elif isinstance(event, QueueAgentThoughtEvent): - agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() - ) - db.session.refresh(agent_thought) - db.session.close() - - if agent_thought: - response = { - 'event': 'agent_thought', - 'id': agent_thought.id, - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'position': agent_thought.position, - 'thought': agent_thought.thought, - 'observation': agent_thought.observation, - 'tool': agent_thought.tool, - 'tool_labels': agent_thought.tool_labels, - 'tool_input': agent_thought.tool_input, - 'created_at': int(self._message.created_at.timestamp()), - 'message_files': agent_thought.files - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - yield self._yield_response(response) - elif isinstance(event, QueueMessageFileEvent): - message_file: MessageFile = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) - db.session.close() - - # get extension - if '.' in message_file.url: - extension = f'.{message_file.url.split(".")[-1]}' - if len(extension) > 10: - extension = '.bin' - else: - extension = '.bin' - # add sign url - url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) - - if message_file: - response = { - 'event': 'message_file', - 'id': message_file.id, - 'type': message_file.type, - 'belongs_to': message_file.belongs_to or 'user', - 'url': url - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - yield self._yield_response(response) - - elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): - chunk = event.chunk - delta_text = chunk.delta.message.content - if delta_text is None: - continue - - if not self._task_state.llm_result.prompt_messages: - self._task_state.llm_result.prompt_messages = chunk.prompt_messages - - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output - self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( - QueueLLMChunkEvent( - chunk=LLMResultChunk( - model=self._task_state.llm_result.model, - prompt_messages=self._task_state.llm_result.prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) - ) - ), PublishFrom.TASK_PIPELINE - ) - - self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE - ) - continue - else: - self._output_moderation_handler.append_new_token(delta_text) - - self._task_state.llm_result.message.content += delta_text - response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent)) - yield self._yield_response(response) - elif isinstance(event, QueueMessageReplaceEvent): - response = { - 'event': 'message_replace', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'answer': event.text, - 'created_at': int(self._message.created_at.timestamp()) - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - yield self._yield_response(response) - elif isinstance(event, QueuePingEvent): - yield "event: ping\n\n" - else: - continue - - def _save_message(self, llm_result: LLMResult) -> None: - """ - Save message. - :param llm_result: llm result - :return: - """ - usage = llm_result.usage - - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() - - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages - ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - - db.session.commit() - - message_was_created.send( - self._message, - application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras - ) - - def _handle_chunk(self, text: str, agent: bool = False) -> dict: - """ - Handle completed event. - :param text: text - :return: - """ - response = { - 'event': 'message' if not agent else 'agent_message', - 'id': self._message.id, - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - 'answer': text, - 'created_at': int(self._message.created_at.timestamp()) - } - - if self._conversation.mode != AppMode.COMPLETION.value: - response['conversation_id'] = self._conversation.id - - return response - - def _handle_error(self, event: QueueErrorEvent) -> Exception: - """ - Handle error event. - :param event: event - :return: - """ - logger.debug("error: %s", event.error) - e = event.error - - if isinstance(e, InvokeAuthorizationError): - return InvokeAuthorizationError('Incorrect API key provided') - elif isinstance(e, InvokeError) or isinstance(e, ValueError): - return e - else: - return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) - - def _error_to_stream_response_data(self, e: Exception) -> dict: - """ - Error to stream response. - :param e: exception - :return: - """ - error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, - QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 - }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} - } - - # Determine the response based on the type of exception - data = None - for k, v in error_responses.items(): - if isinstance(e, k): - data = v - - if data: - data.setdefault('message', getattr(e, 'description', str(e))) - else: - logging.error(e) - data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 - } - - return { - 'event': 'error', - 'task_id': self._application_generate_entity.task_id, - 'message_id': self._message.id, - **data - } - - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - else: - metadata['retriever_resources'] = [] - for resource in self._task_state.metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata - - def _yield_response(self, response: dict) -> str: - """ - Yield response. - :param response: response - :return: - """ - return "data: " + json.dumps(response) + "\n\n" - - def _init_output_moderation(self) -> Optional[OutputModeration]: - """ - Init output moderation. - :return: - """ - app_config = self._application_generate_entity.app_config - sensitive_word_avoidance = app_config.sensitive_word_avoidance - - if sensitive_word_avoidance: - return OutputModeration( - tenant_id=app_config.tenant_id, - app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager - ) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 5e676c40bd56f5..2d480d71564205 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,7 +8,6 @@ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException -from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -17,6 +16,13 @@ CompletionAppGenerateEntity, InvokeFrom, ) +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + CompletionAppBlockingResponse, + CompletionAppStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account @@ -30,21 +36,28 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], queue_manager: AppQueueManager, conversation: Conversation, message: Message, - stream: bool = False) -> Union[dict, Generator]: + user: Union[Account, EndUser], + stream: bool = False) \ + -> Union[ + ChatbotAppBlockingResponse, + CompletionAppBlockingResponse, + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + ]: """ Handle response. :param application_generate_entity: application generate entity :param queue_manager: queue manager :param conversation: conversation :param message: message + :param user: user :param stream: is stream :return: """ @@ -53,11 +66,13 @@ def _handle_response(self, application_generate_entity: Union[ application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, + user=user, + stream=stream ) try: - return generate_task_pipeline.process(stream=stream) + return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedException() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index b1a70a83ba42e3..b3721cfae97694 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -13,8 +13,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db @@ -32,7 +34,7 @@ def generate(self, app_model: App, args: dict, invoke_from: InvokeFrom, stream: bool = True) \ - -> Union[dict, Generator]: + -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -93,7 +95,7 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_response( + response = self._handle_response( application_generate_entity=application_generate_entity, workflow=workflow, queue_manager=queue_manager, @@ -101,6 +103,11 @@ def generate(self, app_model: App, stream=stream ) + return WorkflowAppGenerateResponseConverter.convert( + response=response, + invoke_from=invoke_from + ) + def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: @@ -141,7 +148,10 @@ def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntit workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], - stream: bool = False) -> Union[dict, Generator]: + stream: bool = False) -> Union[ + WorkflowAppBlockingResponse, + Generator[WorkflowAppStreamResponse, None, None] + ]: """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py new file mode 100644 index 00000000000000..6dec3430ded0ed --- /dev/null +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -0,0 +1,66 @@ +import json +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) + + +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = WorkflowAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + return blocking_response.to_dict() + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + return cls.convert_blocking_full_response(blocking_response) + + @classmethod + def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream full response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield 'ping' + continue + + response_chunk = { + 'event': sub_stream_response.event.value, + 'workflow_run_id': chunk.workflow_run_id, + } + + response_chunk.update(sub_stream_response.to_dict()) + yield json.dumps(response_chunk) + + @classmethod + def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ + -> Generator[str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + return cls.convert_stream_full_response(stream_response) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index cd1ea4c81eaf79..1b43ed9d3bb1c1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,13 +1,8 @@ -import json import logging -import time from collections.abc import Generator -from typing import Optional, Union - -from pydantic import BaseModel, Extra +from typing import Any, Union from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, @@ -25,10 +20,16 @@ QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable +from core.app.entities.task_entities import ( + TextChunkStreamResponse, + TextReplaceStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, + WorkflowTaskState, +) +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage +from core.workflow.entities.node_entities import SystemVariable from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -36,54 +37,21 @@ Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, - WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) logger = logging.getLogger(__name__) -class TaskState(BaseModel): - """ - TaskState entity - """ - class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - start_at: float - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - answer: str = "" - metadata: dict = {} - - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - running_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - -class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): +class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _application_generate_entity: WorkflowAppGenerateEntity + _workflow_system_variables: dict[SystemVariable, Any] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -96,18 +64,18 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, :param workflow: workflow :param queue_manager: queue manager :param user: user - :param stream: is stream + :param stream: is streamed """ - self._application_generate_entity = application_generate_entity + super().__init__(application_generate_entity, queue_manager, user, stream) + self._workflow = workflow - self._queue_manager = queue_manager - self._user = user - self._task_state = TaskState() - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream - - def process(self) -> Union[dict, Generator]: + self._workflow_system_variables = { + SystemVariable.FILES: application_generate_entity.files, + } + + self._task_state = WorkflowTaskState() + + def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Process generate task pipeline. :return: @@ -117,11 +85,16 @@ def process(self) -> Union[dict, Generator]: db.session.close() if self._stream: - return self._process_stream_response() + generator = self._process_stream_response() + for stream_response in generator: + yield WorkflowAppStreamResponse( + workflow_run_id=self._task_state.workflow_run_id, + stream_response=stream_response + ) else: return self._process_blocking_response() - def _process_blocking_response(self) -> dict: + def _process_blocking_response(self) -> WorkflowAppBlockingResponse: """ Process blocking response. :return: @@ -130,49 +103,56 @@ def _process_blocking_response(self) -> dict: event = queue_message.event if isinstance(event, QueueErrorEvent): - raise self._handle_error(event) + err = self._handle_error(event) + raise err elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() + self._handle_workflow_start() elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) + self._handle_node_start(event) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) + self._handle_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._on_workflow_finished(event) - - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() + workflow_run = self._handle_workflow_finished(event) - self._task_state.answer = self._output_moderation_handler.moderation_completion( - completion=self._task_state.answer, - public_event=False - ) + # handle output moderation + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer # save workflow app log self._save_workflow_app_log(workflow_run) - response = { - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_run.id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'status': workflow_run.status, - 'outputs': workflow_run.outputs_dict, - 'error': workflow_run.error, - 'elapsed_time': workflow_run.elapsed_time, - 'total_tokens': workflow_run.total_tokens, - 'total_steps': workflow_run.total_steps, - 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) - } - } - - return response + return self._to_blocking_response(workflow_run) else: continue + raise Exception('Queue listening stopped unexpectedly.') + + def _to_blocking_response(self, workflow_run: WorkflowRun) -> WorkflowAppBlockingResponse: + """ + To blocking response. + :param workflow_run: workflow run + :return: + """ + response = WorkflowAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=workflow_run.id, + data=WorkflowAppBlockingResponse.Data( + id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + status=workflow_run.status, + outputs=workflow_run.outputs_dict, + error=workflow_run.error, + elapsed_time=workflow_run.elapsed_time, + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + created_at=int(workflow_run.created_at.timestamp()), + finished_at=int(workflow_run.finished_at.timestamp()) + ) + ) + + return response + def _process_stream_response(self) -> Generator: """ Process stream response. @@ -182,281 +162,60 @@ def _process_stream_response(self) -> Generator: event = message.event if isinstance(event, QueueErrorEvent): - data = self._error_to_stream_response_data(self._handle_error(event)) - yield self._yield_response(data) + err = self._handle_error(event) + yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._on_workflow_start() - - response = { - 'event': 'workflow_started', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_run.id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'sequence_number': workflow_run.sequence_number, - 'created_at': int(workflow_run.created_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_run = self._handle_workflow_start() + yield self._workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._on_node_start(event) - - response = { - 'event': 'node_started', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_node_execution.workflow_run_id, - 'data': { - 'id': workflow_node_execution.id, - 'node_id': workflow_node_execution.node_id, - 'index': workflow_node_execution.index, - 'predecessor_node_id': workflow_node_execution.predecessor_node_id, - 'inputs': workflow_node_execution.inputs_dict, - 'created_at': int(workflow_node_execution.created_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_node_execution = self._handle_node_start(event) + yield self._workflow_node_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._on_node_finished(event) - - response = { - 'event': 'node_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_node_execution.workflow_run_id, - 'data': { - 'id': workflow_node_execution.id, - 'node_id': workflow_node_execution.node_id, - 'index': workflow_node_execution.index, - 'predecessor_node_id': workflow_node_execution.predecessor_node_id, - 'inputs': workflow_node_execution.inputs_dict, - 'process_data': workflow_node_execution.process_data_dict, - 'outputs': workflow_node_execution.outputs_dict, - 'status': workflow_node_execution.status, - 'error': workflow_node_execution.error, - 'elapsed_time': workflow_node_execution.elapsed_time, - 'execution_metadata': workflow_node_execution.execution_metadata_dict, - 'created_at': int(workflow_node_execution.created_at.timestamp()), - 'finished_at': int(workflow_node_execution.finished_at.timestamp()) - } - } - - yield self._yield_response(response) + workflow_node_execution = self._handle_node_finished(event) + yield self._workflow_node_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._on_workflow_finished(event) - - # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() - - self._task_state.answer = self._output_moderation_handler.moderation_completion( - completion=self._task_state.answer, - public_event=False - ) - - self._output_moderation_handler = None - - replace_response = { - 'event': 'text_replace', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, - 'data': { - 'text': self._task_state.answer - } - } + workflow_run = self._handle_workflow_finished(event) - yield self._yield_response(replace_response) + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + yield self._text_replace_to_stream_response(output_moderation_answer) # save workflow app log self._save_workflow_app_log(workflow_run) - workflow_run_response = { - 'event': 'workflow_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': workflow_run.id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'status': workflow_run.status, - 'outputs': workflow_run.outputs_dict, - 'error': workflow_run.error, - 'elapsed_time': workflow_run.elapsed_time, - 'total_tokens': workflow_run.total_tokens, - 'total_steps': workflow_run.total_steps, - 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None - } - } - - yield self._yield_response(workflow_run_response) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output - self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE - ) - - self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE - ) - continue - else: - self._output_moderation_handler.append_new_token(delta_text) + # handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(delta_text) + if should_direct_answer: + continue self._task_state.answer += delta_text - response = self._handle_chunk(delta_text) - yield self._yield_response(response) + yield self._text_chunk_to_stream_response(delta_text) elif isinstance(event, QueueMessageReplaceEvent): - response = { - 'event': 'text_replace', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, - 'data': { - 'text': event.text - } - } - - yield self._yield_response(response) + yield self._text_replace_to_stream_response(event.text) elif isinstance(event, QueuePingEvent): - yield "event: ping\n\n" + yield self._ping_stream_response() else: continue - def _on_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs={ - SystemVariable.FILES: self._application_generate_entity.files - } - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - start_at=time.perf_counter() - ) - - self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.running_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=event.execution_metadata - ) - - if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error - ) - - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] - - db.session.close() - - return workflow_node_execution - - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ - -> WorkflowRun: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.' - ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs - ) - - self._task_state.workflow_run_id = workflow_run.id - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - - db.session.close() - - return workflow_run - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. @@ -486,103 +245,52 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: db.session.commit() db.session.close() - def _handle_chunk(self, text: str) -> dict: + def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: """ Handle completed event. :param text: text :return: """ - response = { - 'event': 'text_chunk', - 'workflow_run_id': self._task_state.workflow_run_id, - 'task_id': self._application_generate_entity.task_id, - 'data': { - 'text': text - } - } + response = TextChunkStreamResponse( + task_id=self._application_generate_entity.task_id, + data=TextChunkStreamResponse.Data(text=text) + ) return response - def _handle_error(self, event: QueueErrorEvent) -> Exception: - """ - Handle error event. - :param event: event - :return: + def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: """ - logger.debug("error: %s", event.error) - e = event.error - - if isinstance(e, InvokeAuthorizationError): - return InvokeAuthorizationError('Incorrect API key provided') - elif isinstance(e, InvokeError) or isinstance(e, ValueError): - return e - else: - return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) - - def _error_to_stream_response_data(self, e: Exception) -> dict: - """ - Error to stream response. - :param e: exception + Text replace to stream response. + :param text: text :return: """ - error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, - QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 - }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} - } - - # Determine the response based on the type of exception - data = None - for k, v in error_responses.items(): - if isinstance(e, k): - data = v - - if data: - data.setdefault('message', getattr(e, 'description', str(e))) - else: - logging.error(e) - data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 - } - - return { - 'event': 'error', - 'task_id': self._application_generate_entity.task_id, - **data - } + return TextReplaceStreamResponse( + task_id=self._application_generate_entity.task_id, + text=TextReplaceStreamResponse.Data(text=text) + ) - def _yield_response(self, response: dict) -> str: + def _handle_output_moderation_chunk(self, text: str) -> bool: """ - Yield response. - :param response: response - :return: + Handle output moderation chunk. + :param text: text + :return: True if output moderation should direct output, otherwise False """ - return "data: " + json.dumps(response) + "\n\n" + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + return True + else: + self._output_moderation_handler.append_new_token(text) - def _init_output_moderation(self) -> Optional[OutputModeration]: - """ - Init output moderation. - :return: - """ - app_config = self._application_generate_entity.app_config - sensitive_word_avoidance = app_config.sensitive_word_avoidance - - if sensitive_word_avoidance: - return OutputModeration( - tenant_id=app_config.tenant_id, - app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager - ) + return False diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py deleted file mode 100644 index 2b373d28e83957..00000000000000 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ /dev/null @@ -1,214 +0,0 @@ -import json -import time -from datetime import datetime -from typing import Optional, Union - -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeType -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser -from models.workflow import ( - CreatedByRole, - Workflow, - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, - WorkflowRunStatus, - WorkflowRunTriggeredFrom, -) - - -class WorkflowBasedGenerateTaskPipeline: - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 - - # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) - - db.session.add(workflow_run) - db.session.commit() - db.session.refresh(workflow_run) - db.session.close() - - return workflow_run - - def _workflow_run_success(self, workflow_run: WorkflowRun, - start_at: float, - total_tokens: int, - total_steps: int, - outputs: Optional[dict] = None) -> WorkflowRun: - """ - Workflow run success - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :return: - """ - workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - workflow_run.outputs = outputs - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() - db.session.refresh(workflow_run) - db.session.close() - - return workflow_run - - def _workflow_run_failed(self, workflow_run: WorkflowRun, - start_at: float, - total_tokens: int, - total_steps: int, - status: WorkflowRunStatus, - error: str) -> WorkflowRun: - """ - Workflow run failed - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param status: status - :param error: error message - :return: - """ - workflow_run.status = status.value - workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() - db.session.refresh(workflow_run) - db.session.close() - - return workflow_run - - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ - # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py new file mode 100644 index 00000000000000..124f4759851aa7 --- /dev/null +++ b/api/core/app/entities/task_entities.py @@ -0,0 +1,395 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.entities import GenerateRouteChunk + + +class StreamGenerateRoute(BaseModel): + """ + StreamGenerateRoute entity + """ + answer_node_id: str + generate_route: list[GenerateRouteChunk] + current_route_position: int = 0 + + +class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution_id: str + node_type: NodeType + start_at: float + + +class TaskState(BaseModel): + """ + TaskState entity + """ + metadata: dict = {} + + +class EasyUITaskState(TaskState): + """ + EasyUITaskState entity + """ + llm_result: LLMResult + + +class WorkflowTaskState(TaskState): + """ + WorkflowTaskState entity + """ + answer: str = "" + + workflow_run_id: Optional[str] = None + start_at: Optional[float] = None + total_tokens: int = 0 + total_steps: int = 0 + + ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None + + +class AdvancedChatTaskState(WorkflowTaskState): + """ + AdvancedChatTaskState entity + """ + usage: LLMUsage + + current_stream_generate_state: Optional[StreamGenerateRoute] = None + + +class StreamEvent(Enum): + """ + Stream event + """ + PING = "ping" + ERROR = "error" + MESSAGE = "message" + MESSAGE_END = "message_end" + MESSAGE_FILE = "message_file" + MESSAGE_REPLACE = "message_replace" + AGENT_THOUGHT = "agent_thought" + AGENT_MESSAGE = "agent_message" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_FINISHED = "workflow_finished" + NODE_STARTED = "node_started" + NODE_FINISHED = "node_finished" + TEXT_CHUNK = "text_chunk" + TEXT_REPLACE = "text_replace" + + +class StreamResponse(BaseModel): + """ + StreamResponse entity + """ + event: StreamEvent + task_id: str + + def to_dict(self) -> dict: + return jsonable_encoder(self) + + +class ErrorStreamResponse(StreamResponse): + """ + ErrorStreamResponse entity + """ + event: StreamEvent = StreamEvent.ERROR + code: str + status: int + message: Optional[str] = None + + +class MessageStreamResponse(StreamResponse): + """ + MessageStreamResponse entity + """ + event: StreamEvent = StreamEvent.MESSAGE + id: str + answer: str + + +class MessageEndStreamResponse(StreamResponse): + """ + MessageEndStreamResponse entity + """ + event: StreamEvent = StreamEvent.MESSAGE_END + id: str + metadata: Optional[dict] = None + + +class MessageFileStreamResponse(StreamResponse): + """ + MessageFileStreamResponse entity + """ + event: StreamEvent = StreamEvent.MESSAGE_FILE + id: str + type: str + belongs_to: str + url: str + + +class MessageReplaceStreamResponse(StreamResponse): + """ + MessageReplaceStreamResponse entity + """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE + answer: str + + +class AgentThoughtStreamResponse(StreamResponse): + """ + AgentThoughtStreamResponse entity + """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT + id: str + position: int + thought: Optional[str] = None + observation: Optional[str] = None + tool: Optional[str] = None + tool_labels: Optional[dict] = None + tool_input: Optional[str] = None + message_files: Optional[list[str]] = None + + +class AgentMessageStreamResponse(StreamResponse): + """ + AgentMessageStreamResponse entity + """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE + id: str + answer: str + + +class WorkflowStartStreamResponse(StreamResponse): + """ + WorkflowStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + workflow_id: str + sequence_number: int + created_at: int + + event: StreamEvent = StreamEvent.WORKFLOW_STARTED + workflow_run_id: str + data: Data + + +class WorkflowFinishStreamResponse(StreamResponse): + """ + WorkflowFinishStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + workflow_id: str + sequence_number: int + status: str + outputs: Optional[dict] = None + error: Optional[str] = None + elapsed_time: float + total_tokens: int + total_steps: int + created_at: int + finished_at: int + + event: StreamEvent = StreamEvent.WORKFLOW_FINISHED + workflow_run_id: str + data: Data + + +class NodeStartStreamResponse(StreamResponse): + """ + NodeStartStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + created_at: int + + event: StreamEvent = StreamEvent.NODE_STARTED + workflow_run_id: str + data: Data + + +class NodeFinishStreamResponse(StreamResponse): + """ + NodeFinishStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + node_id: str + node_type: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + status: str + error: Optional[str] = None + elapsed_time: float + execution_metadata: Optional[dict] = None + created_at: int + finished_at: int + + event: StreamEvent = StreamEvent.NODE_FINISHED + workflow_run_id: str + data: Data + + +class TextChunkStreamResponse(StreamResponse): + """ + TextChunkStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + text: str + + event: StreamEvent = StreamEvent.TEXT_CHUNK + data: Data + + +class TextReplaceStreamResponse(StreamResponse): + """ + TextReplaceStreamResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + text: str + + event: StreamEvent = StreamEvent.TEXT_REPLACE + data: Data + + +class PingStreamResponse(StreamResponse): + """ + PingStreamResponse entity + """ + event: StreamEvent = StreamEvent.PING + + +class AppStreamResponse(BaseModel): + """ + AppStreamResponse entity + """ + stream_response: StreamResponse + + +class ChatbotAppStreamResponse(AppStreamResponse): + """ + ChatbotAppStreamResponse entity + """ + conversation_id: str + message_id: str + created_at: int + + +class CompletionAppStreamResponse(AppStreamResponse): + """ + CompletionAppStreamResponse entity + """ + message_id: str + created_at: int + + +class WorkflowAppStreamResponse(AppStreamResponse): + """ + WorkflowAppStreamResponse entity + """ + workflow_run_id: str + + +class AppBlockingResponse(BaseModel): + """ + AppBlockingResponse entity + """ + task_id: str + + def to_dict(self) -> dict: + return jsonable_encoder(self) + + +class ChatbotAppBlockingResponse(AppBlockingResponse): + """ + ChatbotAppBlockingResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + mode: str + conversation_id: str + message_id: str + answer: str + metadata: dict = {} + created_at: int + + data: Data + + +class CompletionAppBlockingResponse(AppBlockingResponse): + """ + CompletionAppBlockingResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + mode: str + message_id: str + answer: str + metadata: dict = {} + created_at: int + + data: Data + + +class WorkflowAppBlockingResponse(AppBlockingResponse): + """ + WorkflowAppBlockingResponse entity + """ + class Data(BaseModel): + """ + Data entity + """ + id: str + workflow_id: str + status: str + outputs: Optional[dict] = None + error: Optional[str] = None + elapsed_time: float + total_tokens: int + total_steps: int + created_at: int + finished_at: int + + workflow_run_id: str + data: Data diff --git a/api/core/app/task_pipeline/__init__.py b/api/core/app/task_pipeline/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py new file mode 100644 index 00000000000000..2606b56bcd8bc4 --- /dev/null +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -0,0 +1,153 @@ +import logging +import time +from typing import Optional, Union + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueErrorEvent, +) +from core.app.entities.task_entities import ( + ErrorStreamResponse, + PingStreamResponse, + TaskState, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.moderation.output_moderation import ModerationRule, OutputModeration +from models.account import Account +from models.model import EndUser + +logger = logging.getLogger(__name__) + + +class BasedGenerateTaskPipeline: + """ + BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + _task_state: TaskState + _application_generate_entity: AppGenerateEntity + + def __init__(self, application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param user: user + :param stream: stream + """ + self._application_generate_entity = application_generate_entity + self._queue_manager = queue_manager + self._user = user + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + self._stream = stream + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return ErrorStreamResponse( + task_id=self._application_generate_entity.task_id, + **data + ) + + def _ping_stream_response(self) -> PingStreamResponse: + """ + Ping stream response. + :return: + """ + return PingStreamResponse(task_id=self._application_generate_entity.task_id) + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. + :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + queue_manager=self._queue_manager + ) + + def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + """ + Handle output moderation when task finished. + :param completion: completion + :return: + """ + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + completion = self._output_moderation_handler.moderation_completion( + completion=completion, + public_event=False + ) + + self._output_moderation_handler = None + + return completion + + return None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py new file mode 100644 index 00000000000000..c7c380e57c1c2f --- /dev/null +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -0,0 +1,445 @@ +import logging +import time +from collections.abc import Generator +from typing import Optional, Union, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + AgentMessageStreamResponse, + AgentThoughtStreamResponse, + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + CompletionAppBlockingResponse, + CompletionAppStreamResponse, + EasyUITaskState, + MessageEndStreamResponse, +) +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from events.message_event import message_was_created +from extensions.ext_database import db +from models.account import Account +from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought + +logger = logging.getLogger(__name__) + + +class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): + """ + EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + _task_state: EasyUITaskState + _application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ] + + def __init__(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param user: user + :param stream: stream + """ + super().__init__(application_generate_entity, queue_manager, user, stream) + self._model_config = application_generate_entity.model_config + self._conversation = conversation + self._message = message + + self._task_state = EasyUITaskState( + llm_result=LLMResult( + model=self._model_config.model, + prompt_messages=[], + message=AssistantPromptMessage(content=""), + usage=LLMUsage.empty_usage() + ) + ) + + def process(self) -> Union[ + ChatbotAppBlockingResponse, + CompletionAppBlockingResponse, + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + ]: + """ + Process generate task pipeline. + :return: + """ + db.session.refresh(self._conversation) + db.session.refresh(self._message) + db.session.close() + + if self._stream: + generator = self._process_stream_response() + for stream_response in generator: + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + yield CompletionAppStreamResponse( + message_id=self._message.id, + created_at=int(self._message.created_at.timestamp()), + stream_response=stream_response + ) + else: + yield ChatbotAppStreamResponse( + conversation_id=self._conversation.id, + message_id=self._message.id, + created_at=int(self._message.created_at.timestamp()), + stream_response=stream_response + ) + + # yield "data: " + json.dumps(response) + "\n\n" + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: + """ + Process blocking response. + :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + err = self._handle_error(event) + raise err + elif isinstance(event, QueueRetrieverResourcesEvent): + self._handle_retriever_resources(event) + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = self._handle_annotation_reply(event) + if annotation: + self._task_state.llm_result.message.content = annotation.content + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): + if isinstance(event, QueueMessageEndEvent): + self._task_state.llm_result = event.llm_result + else: + self._handle_stop(event) + + # handle output moderation + output_moderation_answer = self._handle_output_moderation_when_task_finished( + self._task_state.llm_result.message.content + ) + if output_moderation_answer: + self._task_state.llm_result.message.content = output_moderation_answer + + # Save message + self._save_message() + + return self._to_blocking_response() + else: + continue + + raise Exception('Queue listening stopped unexpectedly.') + + def _process_stream_response(self) -> Generator: + """ + Process stream response. + :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + err = self._handle_error(event) + yield self._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): + if isinstance(event, QueueMessageEndEvent): + self._task_state.llm_result = event.llm_result + else: + self._handle_stop(event) + + # handle output moderation + output_moderation_answer = self._handle_output_moderation_when_task_finished( + self._task_state.llm_result.message.content + ) + if output_moderation_answer: + self._task_state.llm_result.message.content = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message() + + yield self._message_end_to_stream_response() + elif isinstance(event, QueueRetrieverResourcesEvent): + self._handle_retriever_resources(event) + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = self._handle_annotation_reply(event) + if annotation: + self._task_state.llm_result.message.content = annotation.content + elif isinstance(event, QueueAgentThoughtEvent): + yield self._agent_thought_to_stream_response(event) + elif isinstance(event, QueueMessageFileEvent): + response = self._message_file_to_stream_response(event) + if response: + yield response + elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): + chunk = event.chunk + delta_text = chunk.delta.message.content + if delta_text is None: + continue + + if not self._task_state.llm_result.prompt_messages: + self._task_state.llm_result.prompt_messages = chunk.prompt_messages + + # handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(delta_text) + if should_direct_answer: + continue + + self._task_state.llm_result.message.content += delta_text + + if isinstance(event, QueueLLMChunkEvent): + yield self._message_to_stream_response(delta_text, self._message.id) + else: + yield self._agent_message_to_stream_response(delta_text, self._message.id) + elif isinstance(event, QueueMessageReplaceEvent): + yield self._message_replace_to_stream_response(answer=event.text) + elif isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + else: + continue + + def _save_message(self) -> None: + """ + Save message. + :return: + """ + llm_result = self._task_state.llm_result + usage = llm_result.usage + + self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + + self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + self._model_config.mode, + self._task_state.llm_result.prompt_messages + ) + self._message.message_tokens = usage.prompt_tokens + self._message.message_unit_price = usage.prompt_unit_price + self._message.message_price_unit = usage.prompt_price_unit + self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ + if llm_result.message.content else '' + self._message.answer_tokens = usage.completion_tokens + self._message.answer_unit_price = usage.completion_unit_price + self._message.answer_price_unit = usage.completion_price_unit + self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.total_price = usage.total_price + self._message.currency = usage.currency + + db.session.commit() + + message_was_created.send( + self._message, + application_generate_entity=self._application_generate_entity, + conversation=self._conversation, + is_first_message=self._application_generate_entity.app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT + ] and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras + ) + + def _handle_stop(self, event: QueueStopEvent) -> None: + """ + Handle stop. + :return: + """ + model_config = self._model_config + model = model_config.model + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + # calculate num tokens + prompt_tokens = 0 + if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: + prompt_tokens = model_type_instance.get_num_tokens( + model, + model_config.credentials, + self._task_state.llm_result.prompt_messages + ) + + completion_tokens = 0 + if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: + completion_tokens = model_type_instance.get_num_tokens( + model, + model_config.credentials, + [self._task_state.llm_result.message] + ) + + credentials = model_config.credentials + + # transform usage + self._task_state.llm_result.usage = model_type_instance._calc_response_usage( + model, + credentials, + prompt_tokens, + completion_tokens + ) + + def _to_blocking_response(self) -> ChatbotAppBlockingResponse: + """ + To blocking response. + :return: + """ + self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + + extras = {} + if self._task_state.metadata: + extras['metadata'] = self._task_state.metadata + + if self._conversation.mode != AppMode.COMPLETION.value: + response = CompletionAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=CompletionAppBlockingResponse.Data( + id=self._message.id, + mode=self._conversation.mode, + message_id=self._message.id, + answer=self._task_state.llm_result.message.content, + created_at=int(self._message.created_at.timestamp()), + **extras + ) + ) + else: + response = ChatbotAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=ChatbotAppBlockingResponse.Data( + id=self._message.id, + mode=self._conversation.mode, + conversation_id=self._conversation.id, + message_id=self._message.id, + answer=self._task_state.llm_result.message.content, + created_at=int(self._message.created_at.timestamp()), + **extras + ) + ) + + return response + + def _message_end_to_stream_response(self) -> MessageEndStreamResponse: + """ + Message end to stream response. + :return: + """ + self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + + extras = {} + if self._task_state.metadata: + extras['metadata'] = self._task_state.metadata + + return MessageEndStreamResponse( + task_id=self._application_generate_entity.task_id, + id=self._message.id, + **extras + ) + + def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: + """ + Agent message to stream response. + :param answer: answer + :param message_id: message id + :return: + """ + return AgentMessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_id, + answer=answer + ) + + def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: + """ + Agent thought to stream response. + :param event: agent thought event + :return: + """ + agent_thought: MessageAgentThought = ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.id == event.agent_thought_id) + .first() + ) + db.session.refresh(agent_thought) + db.session.close() + + if agent_thought: + return AgentThoughtStreamResponse( + task_id=self._application_generate_entity.task_id, + id=agent_thought.id, + position=agent_thought.position, + thought=agent_thought.thought, + observation=agent_thought.observation, + tool=agent_thought.tool, + tool_labels=agent_thought.tool_labels, + tool_input=agent_thought.tool_input, + message_files=agent_thought.files + ) + + return None + + def _handle_output_moderation_chunk(self, text: str) -> bool: + """ + Handle output moderation chunk. + :param text: text + :return: True if output moderation should direct output, otherwise False + """ + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueLLMChunkEvent( + chunk=LLMResultChunk( + model=self._task_state.llm_result.model, + prompt_messages=self._task_state.llm_result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) + ) + ) + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + return True + else: + self._output_moderation_handler.append_new_token(text) + + return False diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py new file mode 100644 index 00000000000000..305b560f95d505 --- /dev/null +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -0,0 +1,142 @@ +from typing import Optional, Union + +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + InvokeFrom, +) +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueMessageFileEvent, + QueueRetrieverResourcesEvent, +) +from core.app.entities.task_entities import ( + AdvancedChatTaskState, + EasyUITaskState, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, +) +from core.tools.tool_file_manager import ToolFileManager +from extensions.ext_database import db +from models.model import MessageAnnotation, MessageFile +from services.annotation_service import AppAnnotationService + + +class MessageCycleManage: + _application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ] + _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + + def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + """ + Handle annotation reply. + :param event: event + :return: + """ + annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + if annotation: + account = annotation.account + self._task_state.metadata['annotation_reply'] = { + 'id': annotation.id, + 'account': { + 'id': annotation.account_id, + 'name': account.name if account else 'Dify user' + } + } + + return annotation + + return None + + def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: + """ + Handle retriever resources. + :param event: event + :return: + """ + self._task_state.metadata['retriever_resources'] = event.retriever_resources + + def _get_response_metadata(self) -> dict: + """ + Get response metadata by invoke from. + :return: + """ + metadata = {} + + # show_retrieve_source + if 'retriever_resources' in self._task_state.metadata: + metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] + + # show annotation reply + if 'annotation_reply' in self._task_state.metadata: + metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] + + # show usage + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['usage'] = self._task_state.metadata['usage'] + + return metadata + + def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + """ + Message file to stream response. + :param event: event + :return: + """ + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter(MessageFile.id == event.message_file_id) + .first() + ) + + if message_file: + # get extension + if '.' in message_file.url: + extension = f'.{message_file.url.split(".")[-1]}' + if len(extension) > 10: + extension = '.bin' + else: + extension = '.bin' + # add sign url + url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) + + return MessageFileStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_file.id, + type=message_file.type, + belongs_to=message_file.belongs_to or 'user', + url=url + ) + + return None + + def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: + """ + Message to stream response. + :param answer: answer + :param message_id: message id + :return: + """ + return MessageStreamResponse( + task_id=self._application_generate_entity.task_id, + id=message_id, + answer=answer + ) + + def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: + """ + Message replace to stream response. + :param answer: answer + :return: + """ + return MessageReplaceStreamResponse( + task_id=self._application_generate_entity.task_id, + answer=answer + ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py new file mode 100644 index 00000000000000..1af2074c056e07 --- /dev/null +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -0,0 +1,457 @@ +import json +import time +from datetime import datetime +from typing import Any, Optional, Union + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + AdvancedChatTaskState, + NodeExecutionInfo, + NodeFinishStreamResponse, + NodeStartStreamResponse, + WorkflowFinishStreamResponse, + WorkflowStartStreamResponse, + WorkflowTaskState, +) +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) + + +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] + _workflow_system_variables: dict[SystemVariable, Any] + + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _workflow_run_success(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Optional[dict] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :return: + """ + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value + workflow_run.outputs = outputs + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _workflow_run_failed(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + status: WorkflowRunStatus, + error: str) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param status: status + :param error: error message + :return: + """ + workflow_run.status = status.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run: workflow run + :param node_id: node id + :param node_type: node type + :param node_title: node title + :param node_run_index: run index + :param predecessor_node_id: predecessor node id if exists + :return: + """ + # init workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by + ) + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + """ + Workflow node execution success + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param inputs: inputs + :param process_data: process data + :param outputs: outputs + :param execution_metadata: execution metadata + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + error: str) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param error: error message + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _workflow_start_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + """ + Workflow start to stream response. + :param task_id: task id + :param workflow_run: workflow run + :return: + """ + return WorkflowStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=WorkflowStartStreamResponse.Data( + id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + sequence_number=workflow_run.sequence_number, + created_at=int(workflow_run.created_at.timestamp()) + ) + ) + + def _workflow_finish_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + """ + Workflow finish to stream response. + :param task_id: task id + :param workflow_run: workflow run + :return: + """ + return WorkflowFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=WorkflowFinishStreamResponse.Data( + id=workflow_run.id, + workflow_id=workflow_run.workflow_id, + sequence_number=workflow_run.sequence_number, + status=workflow_run.status, + outputs=workflow_run.outputs_dict, + error=workflow_run.error, + elapsed_time=workflow_run.elapsed_time, + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + created_at=int(workflow_run.created_at.timestamp()), + finished_at=int(workflow_run.finished_at.timestamp()) + ) + ) + + def _workflow_node_start_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ + -> NodeStartStreamResponse: + """ + Workflow node start to stream response. + :param task_id: task id + :param workflow_node_execution: workflow node execution + :return: + """ + return NodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeStartStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + created_at=int(workflow_node_execution.created_at.timestamp()) + ) + ) + + def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ + -> NodeFinishStreamResponse: + """ + Workflow node finish to stream response. + :param task_id: task id + :param workflow_node_execution: workflow node execution + :return: + """ + return NodeFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeFinishStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()) + ) + ) + + def _handle_workflow_start(self) -> WorkflowRun: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + user_inputs=self._application_generate_entity.inputs, + system_inputs=self._workflow_system_variables + ) + + self._task_state.workflow_run_id = workflow_run.id + + db.session.close() + + return workflow_run + + def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run=workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + latest_node_execution_info = NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + node_type=event.node_type, + start_at=time.perf_counter() + ) + + self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + + self._task_state.total_steps += 1 + + db.session.close() + + return workflow_node_execution + + def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: + current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if workflow_node_execution.node_type == NodeType.LLM.value: + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + error=event.error + ) + + db.session.close() + + return workflow_node_execution + + def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + + workflow_run = self._workflow_run_success( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=outputs + ) + + self._task_state.workflow_run_id = workflow_run.id + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + + db.session.close() + + return workflow_run diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 1b955c6edd2442..ad30bcfa079b35 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -5,7 +5,6 @@ from langchain.schema import BaseOutputParser from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT -from core.model_runtime.errors.invoke import InvokeError class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 905ee1f80da0e3..8b345dba004f24 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,8 +1,7 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional from pydantic import BaseModel -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index a501113dc313b8..a899157e6a895b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,12 +1,12 @@ import threading -from typing import cast, Any +from typing import Any, cast -from flask import current_app, Flask +from flask import Flask, current_app from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus -from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType @@ -14,12 +14,12 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rerank.rerank import RerankRunner from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment, Document +from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { diff --git a/api/libs/helper.py b/api/libs/helper.py index 3eb14c50f049e3..f9cf590b7ac94b 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -1,12 +1,16 @@ +import json import random import re import string import subprocess import uuid +from collections.abc import Generator from datetime import datetime from hashlib import sha256 +from typing import Union from zoneinfo import available_timezones +from flask import Response, stream_with_context from flask_restful import fields @@ -142,3 +146,14 @@ def get_remote_ip(request): def generate_text_hash(text: str) -> str: hash_text = str(text) + 'None' return sha256(hash_text.encode()).hexdigest() + + +def compact_generate_response(response: Union[dict, Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') diff --git a/api/services/completion_service.py b/api/services/app_generate_service.py similarity index 50% rename from api/services/completion_service.py rename to api/services/app_generate_service.py index eb31ccbb3bf1ed..185d9ba89f5474 100644 --- a/api/services/completion_service.py +++ b/api/services/app_generate_service.py @@ -1,20 +1,26 @@ from collections.abc import Generator from typing import Any, Union +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from models.model import Account, App, AppMode, EndUser +from services.workflow_service import WorkflowService -class CompletionService: +class AppGenerateService: @classmethod - def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, - invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]: + def generate(cls, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + streaming: bool = True) -> Union[dict, Generator[dict, None, None]]: """ - App Completion + App Content Generate :param app_model: app model :param user: user :param args: args @@ -46,8 +52,28 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from=invoke_from, stream=streaming ) + elif app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, invoke_from) + return AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, invoke_from) + return WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) else: - raise ValueError('Invalid app mode') + raise ValueError(f'Invalid app mode {app_model.mode}') @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], @@ -69,3 +95,27 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], invoke_from=invoke_from, stream=streaming ) + + @classmethod + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: + """ + Get workflow + :param app_model: app model + :param invoke_from: invoke from + :return: + """ + workflow_service = WorkflowService() + if invoke_from == InvokeFrom.DEBUGGER: + # fetch draft workflow by app_model + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError('Workflow not initialized') + else: + # fetch published workflow by app_model + workflow = workflow_service.get_published_workflow(app_model=app_model) + + if not workflow: + raise ValueError('Workflow not published') + + return workflow diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 55f2526fbfc827..a768a4a55bb85d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,22 +1,17 @@ import json import time -from collections.abc import Generator from datetime import datetime -from typing import Optional, Union +from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, EndUser +from models.model import App, AppMode from models.workflow import ( CreatedByRole, Workflow, @@ -167,63 +162,6 @@ def get_default_block_config(self, node_type: str, filters: Optional[dict] = Non workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) - def run_advanced_chat_draft_workflow(self, app_model: App, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom) -> Union[dict, Generator]: - """ - Run advanced chatbot draft workflow - """ - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - - if not draft_workflow: - raise ValueError('Workflow not initialized') - - # run draft workflow - app_generator = AdvancedChatAppGenerator() - response = app_generator.generate( - app_model=app_model, - workflow=draft_workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=True - ) - - return response - - def run_draft_workflow(self, app_model: App, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom) -> Union[dict, Generator]: - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - - if not draft_workflow: - raise ValueError('Workflow not initialized') - - # run draft workflow - app_generator = WorkflowAppGenerator() - response = app_generator.generate( - app_model=app_model, - workflow=draft_workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=True - ) - - return response - - def stop_workflow_task(self, task_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom) -> None: - """ - Stop workflow task - """ - AppQueueManager.set_stop_flag(task_id, invoke_from, user.id) - def run_draft_workflow_node(self, app_model: App, node_id: str, user_inputs: dict,