From 381b3d5016d0f8672b81a0d91fbc5f544d2690d0 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 16:55:59 +0800 Subject: [PATCH 001/160] optimize get app model to wraps --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/__init__.py | 21 ---- api/controllers/console/app/app.py | 100 +++++++----------- api/controllers/console/app/audio.py | 23 ++-- api/controllers/console/app/completion.py | 36 ++----- api/controllers/console/app/conversation.py | 59 ++++------- api/controllers/console/app/message.py | 64 ++++------- api/controllers/console/app/model_config.py | 17 ++- api/controllers/console/app/site.py | 14 +-- api/controllers/console/app/statistic.py | 38 +++---- api/controllers/console/app/workflow.py | 20 ++++ api/controllers/console/app/wraps.py | 55 ++++++++++ api/core/app_runner/basic_app_runner.py | 4 +- api/core/entities/application_entities.py | 20 ++++ api/core/prompt/prompt_transform.py | 20 +--- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 2 +- 17 files changed, 232 insertions(+), 265 deletions(-) create mode 100644 api/controllers/console/app/workflow.py create mode 100644 api/controllers/console/app/wraps.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ecfdc38612ce64..934b19116b1f80 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ from . import admin, apikey, extension, feature, setup, version # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic) + model_config, site, statistic, workflow) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py index b0b07517f10aad..e69de29bb2d1d6 100644 --- a/api/controllers/console/app/__init__.py +++ b/api/controllers/console/app/__init__.py @@ -1,21 +0,0 @@ -from controllers.console.app.error import AppUnavailableError -from extensions.ext_database import db -from flask_login import current_user -from models.model import App -from werkzeug.exceptions import NotFound - - -def _get_app(app_id, mode=None): - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() - - if not app: - raise NotFound("App not found") - - if mode and app.mode != mode: - raise NotFound("The {} app not found".format(mode)) - - return app diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ff974054155f22..c366ace93a2678 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,7 +9,8 @@ from constants.languages import demo_model_templates, languages from constants.model_template import model_templates from controllers.console import api -from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -31,13 +32,6 @@ from core.tools.tool_manager import ToolManager from 
core.entities.application_entities import AgentToolEntity -def _get_app(app_id, tenant_id): - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() - if not app: - raise AppNotFoundError - return app - - class AppListApi(Resource): @setup_required @@ -234,14 +228,12 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields_with_site) - def get(self, app_id): + def get(self, app_model): """Get app detail""" - app_id = str(app_id) - app: App = _get_app(app_id, current_user.current_tenant_id) - # get original app model config - model_config: AppModelConfig = app.app_model_config + model_config: AppModelConfig = app_model.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get('tools') or []: @@ -277,27 +269,24 @@ def get(self, app_id): # override agent mode model_config.agent_mode = json.dumps(agent_mode) - return app + return app_model @setup_required @login_required @account_initialization_required - def delete(self, app_id): + @get_app_model + def delete(self, app_model): """Delete app""" - app_id = str(app_id) - if not current_user.is_admin_or_owner: raise Forbidden() - app = _get_app(app_id, current_user.current_tenant_id) - - db.session.delete(app) + db.session.delete(app_model) db.session.commit() # todo delete related data?? # model_config, site, api_token, conversation, message, message_feedback, message_annotation - app_was_deleted.send(app) + app_was_deleted.send(app_model) return {'result': 'success'}, 204 @@ -306,86 +295,77 @@ class AppNameApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app.name = args.get('name') - app.updated_at = datetime.utcnow() + app_model.name = args.get('name') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppIconApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') - app.updated_at = datetime.utcnow() + app_model.icon = args.get('icon') + app_model.icon_background = args.get('icon_background') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppSiteStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_site', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id).first() - if not app: - raise AppNotFoundError 
- if args.get('enable_site') == app.enable_site: - return app + if args.get('enable_site') == app_model.enable_site: + return app_model - app.enable_site = args.get('enable_site') - app.updated_at = datetime.utcnow() + app_model.enable_site = args.get('enable_site') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppApiStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) + if args.get('enable_api') == app_model.enable_api: + return app_model - if args.get('enable_api') == app.enable_api: - return app - - app.enable_api = args.get('enable_api') - app.updated_at = datetime.utcnow() + app_model.enable_api = args.get('enable_api') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppCopy(Resource): @@ -415,16 +395,14 @@ def create_app_model_config_copy(app_config, copy_app_id): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - - copy_app = self.create_app_copy(app) + def post(self, app_model): + copy_app = self.create_app_copy(app_model) db.session.add(copy_app) app_config = db.session.query(AppModelConfig). \ - filter(AppModelConfig.app_id == app_id). \ + filter(AppModelConfig.app_id == app_model.id). \ one_or_none() if app_config: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 77eaf136fc644f..daa5570f9aacea 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,7 +6,6 @@ import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -18,8 +17,10 @@ ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -36,10 +37,8 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): file = request.files['file'] try: @@ -80,10 +79,8 @@ class ChatMessageTextApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, None) - + @get_app_model + def post(self, app_model): try: response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, @@ -120,9 +117,11 @@ def post(self, app_id): class TextModesApi(Resource): - def get(self, app_id: str): - app_model = _get_app(str(app_id)) - + @setup_required + @login_required + 
@account_initialization_required + @get_app_model + def get(self, app_model): try: parser = reqparse.RequestParser() parser.add_argument('language', type=str, required=True, location='args') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f01d2afa031699..f378f7b2180c2e 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -10,7 +10,6 @@ import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,10 +18,11 @@ ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.entities.application_entities import InvokeFrom, AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -36,12 +36,8 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') @@ -93,12 +89,8 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) @@ -110,12 +102,8 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') @@ -179,12 +167,8 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 452b0fddf6c014..4ee1ee40359993 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -9,9 +9,10 @@ from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from 
controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from fields.conversation_fields import ( conversation_detail_fields, @@ -29,10 +30,9 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(conversation_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -43,10 +43,7 @@ def get(self, app_id): parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'completion') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') if args['keyword']: query = query.join( @@ -106,24 +103,22 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(conversation_message_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - return _get_conversation(app_id, conversation_id, 'completion') + return _get_conversation(app_model, conversation_id) @setup_required @login_required @account_initialization_required - def delete(self, app_id, conversation_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.CHAT) + def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) - app = _get_app(app_id, 'chat') - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") @@ -139,10 +134,9 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.CHAT) @marshal_with(conversation_with_summary_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -154,10 +148,7 @@ def get(self, app_id): parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'chat') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'chat') if args['keyword']: query = query.join( @@ -228,25 +219,22 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.CHAT) @marshal_with(conversation_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): 
         conversation_id = str(conversation_id)
 
-        return _get_conversation(app_id, conversation_id, 'chat')
+        return _get_conversation(app_model, conversation_id)
 
     @setup_required
     @login_required
+    @get_app_model(mode=AppMode.CHAT)
     @account_initialization_required
-    def delete(self, app_id, conversation_id):
-        app_id = str(app_id)
+    def delete(self, app_model, conversation_id):
         conversation_id = str(conversation_id)
 
-        # get app info
-        app = _get_app(app_id, 'chat')
-
         conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
@@ -263,12 +251,9 @@ def delete(self, app_id, conversation_id):
 api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
 
 
-def _get_conversation(app_id, conversation_id, mode):
-    # get app info
-    app = _get_app(app_id, mode)
-
+def _get_conversation(app_model, conversation_id):
     conversation = db.session.query(Conversation) \
-        .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
 
     if not conversation:
         raise NotFound("Conversation Not Exists.")
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 0064dbe663b90f..360602b9c2a98e 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -10,7 +10,6 @@
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 from controllers.console import api
-from controllers.console.app import _get_app
 from controllers.console.app.error import (
     AppMoreLikeThisDisabledError,
     CompletionRequestError,
@@ -18,9 +17,10 @@
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
+from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.entities.application_entities import InvokeFrom
+from core.entities.application_entities import InvokeFrom, AppMode
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
@@ -46,14 +46,10 @@ class ChatMessageListApi(Resource):
 
     @setup_required
     @login_required
+    @get_app_model(mode=AppMode.CHAT)
     @account_initialization_required
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_id):
-        app_id = str(app_id)
-
-        # get app info
-        app = _get_app(app_id, 'chat')
-
+    def get(self, app_model):
         parser = reqparse.RequestParser()
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
@@ -62,7 +58,7 @@ def get(self, app_id):
 
         conversation = db.session.query(Conversation).filter(
             Conversation.id == args['conversation_id'],
-            Conversation.app_id == app.id
+            Conversation.app_id == app_model.id
         ).first()
 
         if not conversation:
@@ -110,12 +106,8 @@ class MessageFeedbackApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, app_id):
-        app_id = str(app_id)
-
-        # get app info
-        app = _get_app(app_id)
-
+    @get_app_model
+    def post(self, app_model):
         parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') @@ -125,7 +117,7 @@ def post(self, app_id): message = db.session.query(Message).filter( Message.id == message_id, - Message.app_id == app.id + Message.app_id == app_model.id ).first() if not message: @@ -141,7 +133,7 @@ def post(self, app_id): raise ValueError('rating cannot be None when feedback not exists') else: feedback = MessageFeedback( - app_id=app.id, + app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=args['rating'], @@ -160,21 +152,20 @@ class MessageAnnotationApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check('annotation') + @get_app_model @marshal_with(annotation_fields) - def post(self, app_id): + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - app_id = str(app_id) - parser = reqparse.RequestParser() parser.add_argument('message_id', required=False, type=uuid_value, location='json') parser.add_argument('question', required=True, type=str, location='json') parser.add_argument('answer', required=True, type=str, location='json') parser.add_argument('annotation_reply', required=False, type=dict, location='json') args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) return annotation @@ -183,14 +174,10 @@ class MessageAnnotationCountApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def get(self, app_model): count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app.id + MessageAnnotation.app_id == app_model.id ).count() return {'count': count} @@ -200,8 +187,8 @@ class MessageMoreLikeThisApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.COMPLETION) + def get(self, app_model, message_id): message_id = str(message_id) parser = reqparse.RequestParser() @@ -211,9 +198,6 @@ def get(self, app_id, message_id): streaming = args['response_mode'] == 'streaming' - # get app info - app_model = _get_app(app_id, 'completion') - try: response = CompletionService.generate_more_like_this( app_model=app_model, @@ -257,13 +241,10 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.CHAT) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id, 'chat') - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, @@ -294,14 +275,11 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(message_detail_fields) - def get(self, app_id, message_id): - app_id = str(app_id) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id) - message = db.session.query(Message).filter( Message.id == message_id, Message.app_id == app_model.id diff 
--git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 2095bb6bea4c2f..0f8bc28f6fe14f 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -5,7 +5,7 @@ from flask_restful import Resource from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.entities.application_entities import AgentToolEntity @@ -23,22 +23,19 @@ class ModelConfigResource(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + @get_app_model + def post(self, app_model): """Modify app model config""" - app_id = str(app_id) - - app = _get_app(app_id) - # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, account=current_user, config=request.json, - app_mode=app.mode + app_mode=app_model.mode ) new_app_model_config = AppModelConfig( - app_id=app.id, + app_id=app_model.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) @@ -130,11 +127,11 @@ def post(self, app_id): db.session.add(new_app_model_config) db.session.flush() - app.app_model_config_id = new_app_model_config.id + app_model.app_model_config_id = new_app_model_config.id db.session.commit() app_model_config_was_updated.send( - app, + app_model, app_model_config=new_app_model_config ) diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 4e9d9ed9b45682..256824981e6c72 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -4,7 +4,7 @@ from constants.languages import supported_language from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db @@ -34,13 +34,11 @@ class AppSite(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): + def post(self, app_model): args = parse_app_site_args() - app_id = str(app_id) - app_model = _get_app(app_id) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() @@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id) - + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 7aed7da404aba7..e3bc44d6e93a2d 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -7,9 +7,10 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import 
account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required @@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=AppMode.CHAT) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=AppMode.WORKFLOW) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'completion') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py new file mode 100644 index 00000000000000..5a08e31c164c84 --- /dev/null +++ b/api/controllers/console/app/workflow.py @@ -0,0 +1,20 @@ +from flask_restful import Resource + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from 
core.entities.application_entities import AppMode
+from libs.login import login_required
+
+
+class DefaultBlockConfigApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW])
+    def post(self, app_model):
+        return 'success', 200
+
+
+api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/default-workflow-block-configs')
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
new file mode 100644
index 00000000000000..b3aca51871390c
--- /dev/null
+++ b/api/controllers/console/app/wraps.py
@@ -0,0 +1,55 @@
+from functools import wraps
+from typing import Union, Optional, Callable
+
+from controllers.console.app.error import AppNotFoundError
+from core.entities.application_entities import AppMode
+from extensions.ext_database import db
+from libs.login import current_user
+from models.model import App
+
+
+def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
+    def decorator(view_func):
+        @wraps(view_func)
+        def decorated_view(*args, **kwargs):
+            if not kwargs.get('app_id'):
+                raise ValueError('missing app_id in path parameters')
+
+            app_id = kwargs.get('app_id')
+            app_id = str(app_id)
+
+            del kwargs['app_id']
+
+            app_model = db.session.query(App).filter(
+                App.id == app_id,
+                App.tenant_id == current_user.current_tenant_id,
+                App.status == 'normal'
+            ).first()
+
+            if not app_model:
+                raise AppNotFoundError()
+
+            app_mode = AppMode.value_of(app_model.mode)
+            if mode is not None:
+                if isinstance(mode, list):
+                    modes = mode
+                else:
+                    modes = [mode]
+
+                # [temp] if workflow is in the mode list, then completion should be in the mode list
+                if AppMode.WORKFLOW in modes:
+                    modes.append(AppMode.COMPLETION)
+
+                if app_mode not in modes:
+                    mode_values = {m.value for m in modes}
+                    raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
+
+            kwargs['app_model'] = app_model
+
+            return view_func(*args, **kwargs)
+        return decorated_view
+
+    if view is None:
+        return decorator
+    else:
+        return decorator(view)
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index d3c91337c8f5c1..d1e16f860ca05a 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -4,12 +4,12 @@
 from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
+from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity, \
+    AppMode
 from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
-from core.prompt.prompt_transform import AppMode
 from extensions.ext_database import db
 from models.model import App, Conversation, Message
diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py
index abcf605c92d961..d3231affb28819 100644
--- a/api/core/entities/application_entities.py
+++ b/api/core/entities/application_entities.py
@@ -9,6 +9,26 @@
 from core.model_runtime.entities.model_entities import AIModelEntity
 
 
+class
AppMode(Enum): + COMPLETION = 'completion' # will be deprecated in the future + WORKFLOW = 'workflow' # instead of 'completion' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class ModelConfigEntity(BaseModel): """ Model Config Entity. diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 0a373b7c42f8bd..08d94661b76208 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -7,7 +7,7 @@ from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, - PromptTemplateEntity, + PromptTemplateEntity, AppMode, ) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -25,24 +25,6 @@ from core.prompt.prompt_template import PromptTemplateParser -class AppMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelMode(enum.Enum): COMPLETION = 'completion' CHAT = 'chat' diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d52f6e20c219a8..3cf58d8e09be2f 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,6 +1,7 @@ import copy +from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -13,7 +14,6 @@ COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) -from core.prompt.prompt_transform import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2e21e562665938..ccfb101405cdf5 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -2,11 +2,11 @@ import uuid from core.entities.agent_entities import PlanningStrategy +from core.entities.application_entities import AppMode from core.external_data_tool.factory import ExternalDataToolFactory from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory -from core.prompt.prompt_transform import AppMode from core.provider_manager import ProviderManager from models.account import Account from services.dataset_service import DatasetService From d430136f656606bf8c7bb1c3bed7492d4b901dfb Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 16:56:29 +0800 Subject: [PATCH 002/160] lint --- api/controllers/console/app/completion.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/wraps.py | 3 ++- api/core/app_runner/basic_app_runner.py | 9 +++++++-- api/core/prompt/prompt_transform.py | 3 ++- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f378f7b2180c2e..381d0bbb6b5930 100644 --- 
a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -22,7 +22,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom, AppMode +from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 360602b9c2a98e..5d4f6b7e262d67 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -20,7 +20,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import InvokeFrom, AppMode +from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index b3aca51871390c..fe2b408702d07d 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Union, Optional, Callable +from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from core.entities.application_entities import AppMode diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d1e16f860ca05a..d87302c717ca87 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -4,8 +4,13 @@ from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity, \ - AppMode +from core.entities.application_entities import ( + ApplicationGenerateEntity, + AppMode, + DatasetEntity, + InvokeFrom, + ModelConfigEntity, +) from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 08d94661b76208..4bf96ce2657fe9 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -6,8 +6,9 @@ from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, + AppMode, ModelConfigEntity, - PromptTemplateEntity, AppMode, + PromptTemplateEntity, ) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory From b7c6cba23f24625f41a5446abcec6e210354f04d Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 20:48:54 +0800 Subject: [PATCH 003/160] add workflow models 
---
 api/controllers/console/app/workflow.py       |  21 +-
 .../versions/b289e2408ee2_add_workflow.py     | 143 +++++++++++
 api/models/model.py                           |  20 +-
 api/models/workflow.py                        | 237 ++++++++++++++++++
 4 files changed, 415 insertions(+), 6 deletions(-)
 create mode 100644 api/migrations/versions/b289e2408ee2_add_workflow.py
 create mode 100644 api/models/workflow.py

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 5a08e31c164c84..4acdb4943d1316 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource, reqparse
 
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
@@ -12,9 +12,20 @@ class DefaultBlockConfigApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW])
-    def post(self, app_model):
-        return 'success', 200
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('app_mode', type=str, required=True, nullable=False,
+                            choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args')
+        args = parser.parse_args()
+
+        app_mode = args.get('app_mode')
+        app_mode = AppMode.value_of(app_mode)
 
-api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/default-workflow-block-configs')
+        # TODO: implement this
+
+        return {
+            "blocks": []
+        }
+
+
+api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs')
diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py
new file mode 100644
index 00000000000000..52168a04e764ab
--- /dev/null
+++ b/api/migrations/versions/b289e2408ee2_add_workflow.py
@@ -0,0 +1,143 @@
+"""add workflow
+
+Revision ID: b289e2408ee2
+Revises: 16830a790f0f
+Create Date: 2024-02-19 12:47:24.646954
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'b289e2408ee2'
+down_revision = '16830a790f0f'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + op.create_table('workflow_app_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_app_log_pkey') + ) + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.create_index('workflow_app_log_app_idx', ['tenant_id', 'app_id'], unique=False) + + op.create_table('workflow_node_executions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=True), + sa.Column('index', sa.Integer(), nullable=False), + sa.Column('predecessor_node_id', sa.String(length=255), nullable=True), + sa.Column('node_id', sa.String(length=255), nullable=False), + sa.Column('node_type', sa.String(length=255), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('inputs', sa.Text(), nullable=False), + sa.Column('process_data', sa.Text(), nullable=False), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('execution_metadata', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') + ) + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.create_index('workflow_node_execution_node_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'], unique=False) + batch_op.create_index('workflow_node_execution_workflow_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'workflow_run_id'], unique=False) + + op.create_table('workflow_runs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('sequence_number', sa.Integer(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + 
sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), + sa.Column('currency', sa.String(length=255), nullable=True), + sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') + ) + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from'], unique=False) + + op.create_table('workflows', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_pkey') + ) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_run_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('workflow_run_id') + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('workflow_id') + batch_op.drop_column('chatbot_app_engine') + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_index('workflow_version_idx') + + op.drop_table('workflows') + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_triggerd_from_idx') + + op.drop_table('workflow_runs') + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_workflow_run_idx') + batch_op.drop_index('workflow_node_execution_node_run_idx') + + op.drop_table('workflow_node_executions') + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_app_log_app_idx') + + op.drop_table('workflow_app_logs') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 8776f896730a07..6e7a58ed457213 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -12,6 +12,7 @@ from libs.helper import generate_string from .account import Account, Tenant +from .workflow import WorkflowRun, Workflow class DifySetup(db.Model): @@ -156,12 +157,14 @@ class AppModelConfig(db.Model): agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, default='simple') + prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) + chatbot_app_engine = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + workflow_id = db.Column(UUID) @property def app(self): @@ -261,6 +264,13 @@ def file_upload_dict(self) -> dict: "image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} + @property + def workflow(self): + if self.workflow_id: + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None + def to_dict(self) -> dict: return { "provider": "", @@ -581,6 +591,7 @@ class Message(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + workflow_run_id = db.Column(UUID) @property def user_feedback(self): @@ -679,6 +690,13 @@ def files(self): return files + @property + def workflow_run(self): + if self.workflow_run_id: + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + return None + class MessageFeedback(db.Model): __tablename__ = 'message_feedbacks' diff --git a/api/models/workflow.py b/api/models/workflow.py new file mode 100644 index 00000000000000..59b8eeb6cdb642 --- /dev/null +++ b/api/models/workflow.py @@ -0,0 +1,237 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db + + +class Workflow(db.Model): + """ + Workflow, for `Workflow App` and `Chat App workflow mode`. 
+ + Attributes: + + - id (uuid) Workflow ID, pk + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - type (string) Workflow type + + `workflow` for `Workflow App` + + `chat` for `Chat App workflow mode` + + - version (string) Version + + `draft` for draft version (only one for each app), other for version number (redundant) + + - graph (text) Workflow canvas configuration (JSON) + + The entire canvas configuration JSON, including Node, Edge, and other configurations + + - nodes (array[object]) Node list, see Node Schema + + - edges (array[object]) Edge list, see Edge Schema + + - created_by (uuid) Creator ID + - created_at (timestamp) Creation time + - updated_by (uuid) `optional` Last updater ID + - updated_at (timestamp) `optional` Last update time + """ + + __tablename__ = 'workflows' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_pkey'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(UUID) + updated_at = db.Column(db.DateTime) + + +class WorkflowRun(db.Model): + """ + Workflow Run + + Attributes: + + - id (uuid) Run ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID + - type (string) Workflow type + - triggered_from (string) Trigger source + + `debugging` for canvas debugging + + `app-run` for (published) app execution + + - version (string) Version + - graph (text) Workflow canvas configuration (JSON) + - inputs (text) Input parameters + - status (string) Execution status, `running` / `succeeded` / `failed` + - outputs (text) `optional` Output content + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - total_tokens (int) `optional` Total tokens used + - total_price (decimal) `optional` Total cost + - currency (string) `optional` Currency, such as USD / RMB + - total_steps (int) Total steps (redundant), default 0 + - created_by (uuid) Runner ID + - created_at (timestamp) Run time + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_runs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), + db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + sequence_number = db.Column(db.Integer, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + inputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + outputs = db.Column(db.Text) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + total_tokens = db.Column(db.Integer, nullable=False, 
server_default=db.text('0')) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255)) + total_steps = db.Column(db.Integer, server_default=db.text('0')) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + finished_at = db.Column(db.DateTime) + + +class WorkflowNodeExecution(db.Model): + """ + Workflow Node Execution + + - id (uuid) Execution ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - workflow_id (uuid) Workflow ID + - triggered_from (string) Trigger source + + `single-step` for single-step debugging + + `workflow-run` for workflow execution (debugging / user execution) + + - workflow_run_id (uuid) `optional` Workflow run ID + + Null for single-step debugging. + + - index (int) Execution sequence number, used for displaying Tracing Node order + - predecessor_node_id (string) `optional` Predecessor node ID, used for displaying execution path + - node_id (string) Node ID + - node_type (string) Node type, such as `start` + - title (string) Node title + - inputs (json) All predecessor node variable content used in the node + - process_data (json) Node process data + - outputs (json) `optional` Node output variables + - status (string) Execution status, `running` / `succeeded` / `failed` + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - execution_metadata (text) Metadata + + - total_tokens (int) `optional` Total tokens used + + - total_price (decimal) `optional` Total cost + + - currency (string) `optional` Currency, such as USD / RMB + + - created_at (timestamp) Run time + - created_by (uuid) Runner ID + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_node_executions' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), + db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'workflow_run_id'), + db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + workflow_run_id = db.Column(UUID) + index = db.Column(db.Integer, nullable=False) + predecessor_node_id = db.Column(db.String(255)) + node_id = db.Column(db.String(255), nullable=False) + node_type = db.Column(db.String(255), nullable=False) + title = db.Column(db.String(255), nullable=False) + inputs = db.Column(db.Text, nullable=False) + process_data = db.Column(db.Text, nullable=False) + outputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + execution_metadata = db.Column(db.Text) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by = db.Column(UUID, nullable=False) + finished_at = db.Column(db.DateTime) + + +class WorkflowAppLog(db.Model): + """ + Workflow App execution log, excluding workflow debugging records. 
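The `index` and `predecessor_node_id` columns are what the Tracing view uses to rebuild the execution path of a run. A minimal sketch of reading a run's trace back out, assuming the models above and the shared Flask-SQLAlchemy session; `fetch_node_trace` is an illustrative helper, not part of this patch:

    from extensions.ext_database import db
    from models.workflow import WorkflowNodeExecution

    def fetch_node_trace(tenant_id: str, app_id: str, workflow_id: str,
                         workflow_run_id: str) -> list[WorkflowNodeExecution]:
        # The (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
        # index defined above covers exactly this lookup; ordering by `index`
        # restores the display order of the Tracing nodes.
        return db.session.query(WorkflowNodeExecution) \
            .filter(
                WorkflowNodeExecution.tenant_id == tenant_id,
                WorkflowNodeExecution.app_id == app_id,
                WorkflowNodeExecution.workflow_id == workflow_id,
                WorkflowNodeExecution.triggered_from == 'workflow-run',
                WorkflowNodeExecution.workflow_run_id == workflow_run_id,
            ) \
            .order_by(WorkflowNodeExecution.index.asc()) \
            .all()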
+
+    Attributes:
+
+    - id (uuid) Run ID
+    - tenant_id (uuid) Workspace ID
+    - app_id (uuid) App ID
+    - workflow_id (uuid) Associated Workflow ID
+    - workflow_run_id (uuid) Associated Workflow Run ID
+    - created_from (string) Creation source
+
+        - `service-api` App Execution OpenAPI
+        - `web-app` WebApp
+        - `installed-app` Installed App
+
+    - created_by_role (string) Creator role
+
+        - `account` Console account
+        - `end_user` End user
+
+    - created_by (uuid) Creator ID; which user table it references depends on created_by_role
+    - created_at (timestamp) Creation time
+    """
+
+    __tablename__ = 'workflow_app_logs'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'),
+        db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    app_id = db.Column(UUID, nullable=False)
+    workflow_id = db.Column(UUID, nullable=False)
+    workflow_run_id = db.Column(UUID, nullable=False)
+    created_from = db.Column(db.String(255), nullable=False)
+    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

From 603b1e9ed49c7b4b43033b43dcb76db1ebe5d476 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 19 Feb 2024 20:49:13 +0800
Subject: [PATCH 004/160] lint

---
 api/controllers/console/app/workflow.py              | 1 -
 api/migrations/versions/b289e2408ee2_add_workflow.py | 2 +-
 api/models/model.py                                  | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 4acdb4943d1316..5689c0fd9276c4 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,7 +1,6 @@
 from flask_restful import Resource, reqparse
 
 from controllers.console import api
-from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.entities.application_entities import AppMode
diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py
index 52168a04e764ab..605c66bed1139b 100644
--- a/api/migrations/versions/b289e2408ee2_add_workflow.py
+++ b/api/migrations/versions/b289e2408ee2_add_workflow.py
@@ -5,8 +5,8 @@
 Create Date: 2024-02-19 12:47:24.646954
 
 """
-from alembic import op
 import sqlalchemy as sa
+from alembic import op
 from sqlalchemy.dialects import postgresql
 
 # revision identifiers, used by Alembic.
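Before patch 005 starts wiring controllers, it is worth seeing how the models above fit together at run time. A hedged sketch of recording a published app run and its log entry; the function name and the naive max()+1 sequence allocation are illustrative only (a real implementation would need a lock or a database sequence to be concurrency-safe):

    from sqlalchemy import func

    from extensions.ext_database import db
    from models.workflow import WorkflowAppLog, WorkflowRun

    def record_app_run(tenant_id: str, app_id: str, workflow_id: str,
                       graph_json: str, end_user_id: str) -> WorkflowRun:
        # sequence_number is documented as "incremented within the App,
        # starting from 1"; max()+1 is the simplest (non-concurrent) reading.
        last = db.session.query(func.max(WorkflowRun.sequence_number)) \
            .filter(WorkflowRun.tenant_id == tenant_id,
                    WorkflowRun.app_id == app_id) \
            .scalar() or 0

        run = WorkflowRun(
            tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id,
            sequence_number=last + 1, type='workflow', triggered_from='app-run',
            version='draft', graph=graph_json, status='running',
            created_by=end_user_id,
        )
        db.session.add(run)
        db.session.flush()  # populate run.id from the uuid server default

        # Published app executions (not debugging) also get an app log row.
        db.session.add(WorkflowAppLog(
            tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id,
            workflow_run_id=run.id, created_from='service-api',
            created_by_role='end_user', created_by=end_user_id,
        ))
        db.session.commit()
        return run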
diff --git a/api/models/model.py b/api/models/model.py index 6e7a58ed457213..2b44957b06cf1d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -12,7 +12,7 @@ from libs.helper import generate_string from .account import Account, Tenant -from .workflow import WorkflowRun, Workflow +from .workflow import Workflow, WorkflowRun class DifySetup(db.Model): From 3642dd3a7395a9c7b3a2ad3858bd89d6d089b772 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 20 Feb 2024 21:30:43 +0800 Subject: [PATCH 005/160] add workflow logics --- api/constants/model_template.py | 91 ++++-- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/app.py | 54 ++-- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 3 +- api/controllers/console/app/conversation.py | 3 +- api/controllers/console/app/error.py | 6 + api/controllers/console/app/message.py | 50 +--- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/app/workflow.py | 84 +++++- api/controllers/console/app/wraps.py | 21 +- api/controllers/console/explore/message.py | 47 ---- api/controllers/console/ping.py | 17 ++ api/controllers/console/workspace/account.py | 15 +- api/controllers/console/workspace/members.py | 21 +- api/controllers/web/message.py | 47 ---- api/core/app_runner/basic_app_runner.py | 4 +- api/core/application_manager.py | 34 ++- api/core/entities/application_entities.py | 55 ++-- api/core/prompt/prompt_transform.py | 2 +- api/core/workflow/__init__.py | 0 api/core/workflow/entities/NodeEntities.py | 32 +++ api/core/workflow/entities/__init__.py | 0 api/core/workflow/nodes/__init__.py | 0 api/core/workflow/nodes/end/__init__.py | 0 api/core/workflow/nodes/end/end_node.py | 0 api/core/workflow/nodes/end/entities.py | 25 ++ api/core/workflow/workflow_engine_manager.py | 0 api/fields/annotation_fields.py | 8 +- api/fields/conversation_fields.py | 13 +- api/fields/member_fields.py | 38 +++ api/fields/workflow_fields.py | 16 ++ .../versions/b289e2408ee2_add_workflow.py | 2 +- api/models/model.py | 29 +- api/models/workflow.py | 55 +++- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 19 +- api/services/completion_service.py | 60 +--- api/services/errors/__init__.py | 2 +- api/services/errors/app.py | 2 - api/services/workflow/__init__.py | 0 api/services/workflow/defaults.py | 72 +++++ api/services/workflow/workflow_converter.py | 259 ++++++++++++++++++ api/services/workflow_service.py | 83 ++++++ 44 files changed, 891 insertions(+), 386 deletions(-) create mode 100644 api/controllers/console/ping.py create mode 100644 api/core/workflow/__init__.py create mode 100644 api/core/workflow/entities/NodeEntities.py create mode 100644 api/core/workflow/entities/__init__.py create mode 100644 api/core/workflow/nodes/__init__.py create mode 100644 api/core/workflow/nodes/end/__init__.py create mode 100644 api/core/workflow/nodes/end/end_node.py create mode 100644 api/core/workflow/nodes/end/entities.py create mode 100644 api/core/workflow/workflow_engine_manager.py create mode 100644 api/fields/member_fields.py create mode 100644 api/fields/workflow_fields.py delete mode 100644 api/services/errors/app.py create mode 100644 api/services/workflow/__init__.py create mode 100644 api/services/workflow/defaults.py create mode 100644 api/services/workflow/workflow_converter.py create mode 100644 api/services/workflow_service.py diff --git a/api/constants/model_template.py b/api/constants/model_template.py index d87f7c392610f7..c22306ac87b17b 
100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,10 +1,10 @@ import json model_templates = { - # completion default mode - 'completion_default': { + # workflow default mode + 'workflow_default': { 'app': { - 'mode': 'completion', + 'mode': 'workflow', 'enable_site': True, 'enable_api': True, 'is_demo': False, @@ -15,24 +15,7 @@ 'model_config': { 'provider': '', 'model_id': '', - 'configs': {}, - 'model': json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": {} - }), - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + 'configs': {} } }, @@ -48,14 +31,70 @@ 'status': 'normal' }, 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {}, + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, 'model': json.dumps({ "provider": "openai", - "name": "gpt-3.5-turbo", + "name": "gpt-4", "mode": "chat", - "completion_params": {} + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + }) + } + }, + + # agent default mode + 'agent_default': { + 'app': { + 'mode': 'agent', + 'enable_site': True, + 'enable_api': True, + 'is_demo': False, + 'api_rpm': 0, + 'api_rph': 0, + 'status': 'normal' + }, + 'model_config': { + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, + 'model': json.dumps({ + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } }) } }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 934b19116b1f80..649df278ecd02c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -5,7 +5,7 @@ api = ExternalApi(bp) # Import other controllers -from . import admin, apikey, extension, feature, setup, version +from . 
import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, model_config, site, statistic, workflow) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c366ace93a2678..cf505bedb8aedc 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,7 +26,7 @@ template_list_fields, ) from libs.login import login_required -from models.model import App, AppModelConfig, Site +from models.model import App, AppModelConfig, Site, AppMode from services.app_model_config_service import AppModelConfigService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager @@ -80,7 +80,7 @@ def post(self): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json') + parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') @@ -90,18 +90,7 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - try: - provider_manager = ProviderManager() - default_model_entity = provider_manager.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except (ProviderTokenNotInitError, LLMBadRequestError): - default_model_entity = None - except Exception as e: - logging.exception(e) - default_model_entity = None - + # TODO: MOVE TO IMPORT API if args['model_config'] is not None: # validate config model_config_dict = args['model_config'] @@ -150,27 +139,30 @@ def post(self): if 'mode' not in args or args['mode'] is None: abort(400, message="mode is required") - model_config_template = model_templates[args['mode'] + '_default'] + app_mode = AppMode.value_of(args['mode']) + + model_config_template = model_templates[app_mode.value + '_default'] app = App(**model_config_template['app']) app_model_config = AppModelConfig(**model_config_template['model_config']) - # get model provider - model_manager = ModelManager() - - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) + if app_mode in [AppMode.CHAT, AppMode.AGENT]: + # get model provider + model_manager = ModelManager() + + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None + + if model_instance: + model_dict = app_model_config.model_dict + model_dict['provider'] = model_instance.provider + model_dict['name'] = model_instance.model + app_model_config.model = json.dumps(model_dict) app.name = args['name'] app.mode = args['mode'] diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index daa5570f9aacea..458fa5098f80a8 100644 --- 
a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -20,10 +20,10 @@
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.entities.application_entities import AppMode
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
+from models.model import AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 381d0bbb6b5930..11fdba177d6734 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -22,11 +22,12 @@
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.application_queue_manager import ApplicationQueueManager
-from core.entities.application_entities import AppMode, InvokeFrom
+from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from libs.login import login_required
+from models.model import AppMode
 from services.completion_service import CompletionService
 
 
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 4ee1ee40359993..5d312149f714bb 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -12,7 +12,6 @@
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.entities.application_entities import AppMode
 from extensions.ext_database import db
 from fields.conversation_fields import (
     conversation_detail_fields,
@@ -22,7 +21,7 @@
 )
 from libs.helper import datetime_string
 from libs.login import login_required
-from models.model import Conversation, Message, MessageAnnotation
+from models.model import Conversation, Message, MessageAnnotation, AppMode
 
 
 class CompletionConversationApi(Resource):
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index d7b31906c8de21..b1abb38248f5a2 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException):
     error_code = 'too_many_files'
     description = "Only one file is allowed."
     code = 400
+
+
+class DraftWorkflowNotExist(BaseHTTPException):
+    error_code = 'draft_workflow_not_exist'
+    description = "Draft workflow needs to be initialized."
+ code = 400 diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5d4f6b7e262d67..9a177116eacbc1 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -11,7 +11,6 @@ from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -20,7 +19,6 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -28,10 +26,8 @@ from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode from services.annotation_service import AppAnnotationService -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService @@ -183,49 +179,6 @@ def get(self, app_model): return {'count': count} -class MessageMoreLikeThisApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=AppMode.COMPLETION) - def get(self, app_model, message_id): - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], - location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -291,7 +244,6 @@ def get(self, app_model, message_id): return message -api.add_resource(MessageMoreLikeThisApi, '/apps//completion-messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') api.add_resource(ChatMessageListApi, '/apps//chat-messages', 
endpoint='console_chat_messages') api.add_resource(MessageFeedbackApi, '/apps//feedbacks') diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index e3bc44d6e93a2d..ea4d5971127def 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -10,10 +10,10 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required +from models.model import AppMode class DailyConversationStatistic(Resource): diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5689c0fd9276c4..2794735bbb55a9 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,30 +1,88 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse, marshal_with from controllers.console import api +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode -from libs.login import login_required +from fields.workflow_fields import workflow_fields +from libs.login import login_required, current_user +from models.model import App, ChatbotAppEngine, AppMode +from services.workflow_service import WorkflowService -class DefaultBlockConfigApi(Resource): +class DraftWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + @setup_required @login_required @account_initialization_required - def get(self): + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def post(self, app_model: App): + """ + Sync draft workflow + """ parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, nullable=False, - choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args') + parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() - app_mode = args.get('app_mode') - app_mode = AppMode.value_of(app_mode) - - # TODO: implement this + workflow_service = WorkflowService() + workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) return { - "blocks": [] + "result": "success" } -api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs') +class DefaultBlockConfigApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def get(self, app_model: App): + """ + Get default block config + """ + 
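The two DraftWorkflowApi handlers above form the save/load loop of the new workflow editor. A client-side sketch against the routes registered at the bottom of this file (see below); the console base URL, the bearer token, and the exact error payload are assumptions for illustration, not part of the patch:

    import requests

    CONSOLE_API = 'http://localhost:5001/console/api'      # assumed deployment URL
    HEADERS = {'Authorization': 'Bearer <console-token>'}  # assumed auth scheme
    APP_ID = '<app-uuid>'

    # Fetch the draft; a DraftWorkflowNotExist error (HTTP 400) means the
    # frontend should initialize an empty graph instead.
    resp = requests.get(f'{CONSOLE_API}/apps/{APP_ID}/workflows/draft', headers=HEADERS)
    if resp.ok:
        graph = resp.json()['graph']
    else:
        graph = {'nodes': [], 'edges': []}

    # ... edit the canvas client-side, then sync it back ...
    requests.post(f'{CONSOLE_API}/apps/{APP_ID}/workflows/draft',
                  headers=HEADERS, json={'graph': graph})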
# Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_configs() + + +class ConvertToWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.CHAT) + @marshal_with(workflow_fields) + def post(self, app_model: App): + """ + Convert basic mode of chatbot app to workflow + """ + # convert to workflow mode + workflow_service = WorkflowService() + workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model) + + # return workflow + return workflow + + +api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') +api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index fe2b408702d07d..fe35e723043469 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -3,13 +3,14 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.login import current_user -from models.model import App +from models.model import App, ChatbotAppEngine, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, + mode: Union[AppMode, list[AppMode]] = None, + app_engine: ChatbotAppEngine = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -37,14 +38,20 @@ def decorated_view(*args, **kwargs): else: modes = [mode] - # [temp] if workflow is in the mode list, then completion should be in the mode list - if AppMode.WORKFLOW in modes: - modes.append(AppMode.COMPLETION) - if app_mode not in modes: mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + if app_engine is not None: + if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]: + raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.") + + if app_mode == AppMode.CHAT: + # fetch current app model config + app_model_config = app_model.app_model_config + if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value: + raise AppNotFoundError(f"{app_engine.value} app engine is not supported.") + kwargs['app_model'] = app_model return view_func(*args, **kwargs) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425fa896..bef26b4d994d8f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -12,7 +12,6 @@ import services from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -24,13 +23,10 @@ NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value -from services.completion_service import 
CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -76,48 +72,6 @@ def post(self, installed_app, message_id): return {'result': 'success'} -class MessageMoreLikeThisApi(InstalledAppResource): - def get(self, installed_app, message_id): - app_model = installed_app.app - if app_model.mode != 'completion': - raise NotCompletionAppError() - - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -166,5 +120,4 @@ def get(self, installed_app, message_id): api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py new file mode 100644 index 00000000000000..7664ba8c165db8 --- /dev/null +++ b/api/controllers/console/ping.py @@ -0,0 +1,17 @@ +from flask_restful import Resource + +from controllers.console import api + + +class PingApi(Resource): + + def get(self): + """ + For connection health check + """ + return { + "result": "pong" + } + + +api.add_resource(PingApi, '/ping') diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index b7cfba9d04c6b1..656a4d4cee6af5 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -16,26 +16,13 @@ ) from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import 
CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField -} - class AccountInitApi(Resource): diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf57cd4b24c333..f40ccebf25496a 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,33 +1,18 @@ from flask import current_app from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse +from flask_restful import Resource, abort, marshal_with, reqparse import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from libs.helper import TimestampField +from fields.member_fields import account_with_role_list_fields from libs.login import login_required from models.account import Account from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, -} - -account_list_fields = { - 'accounts': fields.List(fields.Nested(account_fields)) -} - class MemberListApi(Resource): """List all members of current tenant.""" @@ -35,7 +20,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_list_fields) + @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) return {'result': 'success', 'accounts': members}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb2a27..5120f49c5ecf95 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -11,7 +11,6 @@ import services from controllers.web import api from controllers.web.error import ( - AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, CompletionRequestError, NotChatAppError, @@ -21,14 +20,11 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields from libs.helper import TimestampField, uuid_value -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -113,48 +109,6 @@ def post(self, app_model, end_user, 
message_id): return {'result': 'success'} -class MessageMoreLikeThisApi(WebApiResource): - def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': - raise NotCompletionAppError() - - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming - ) - - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -202,5 +156,4 @@ def get(self, app_model, end_user, message_id): api.add_resource(MessageListApi, '/messages') api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d87302c717ca87..26e9cc84aa770e 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -6,7 +6,6 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - AppMode, DatasetEntity, InvokeFrom, ModelConfigEntity, @@ -16,7 +15,7 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationException from extensions.ext_database import db -from models.model import App, Conversation, Message +from models.model import App, Conversation, Message, AppMode logger = logging.getLogger(__name__) @@ -250,6 +249,7 @@ def retrieve_dataset_context(self, tenant_id: str, invoke_from ) + # TODO if (app_record.mode == AppMode.COMPLETION.value and dataset_config and dataset_config.retrieve_config.query_variable): query = inputs.get(dataset_config.retrieve_config.query_variable, "") diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 9aca61c7bb40f2..2fde422d4726b8 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -28,7 +28,7 @@ ModelConfigEntity, PromptTemplateEntity, SensitiveWordAvoidanceEntity, - TextToSpeechEntity, + TextToSpeechEntity, VariableEntity, ) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -93,7 +93,7 @@ def generate(self, tenant_id: str, app_id=app_id, app_model_config_id=app_model_config_id, app_model_config_dict=app_model_config_dict, - 
app_orchestration_config_entity=self._convert_from_app_model_config_dict( + app_orchestration_config_entity=self.convert_from_app_model_config_dict( tenant_id=tenant_id, app_model_config_dict=app_model_config_dict ), @@ -234,7 +234,7 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit logger.exception(e) raise e - def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ + def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ -> AppOrchestrationConfigEntity: """ Convert app model config dict to entity. @@ -384,8 +384,10 @@ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_d config=external_data_tool['config'] ) ) + + properties['variables'] = [] - # current external_data_tools + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] if typ == 'external_data_tool': @@ -397,6 +399,30 @@ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_d config=val['config'] ) ) + elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.TEXT_INPUT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + max_length=variable[typ].get('max_length'), + default=variable[typ].get('default'), + ) + ) + elif typ == VariableEntity.Type.SELECT.value: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.SELECT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + options=variable[typ].get('options'), + default=variable[typ].get('default'), + ) + ) # show retrieve source show_retrieve_source = False diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index d3231affb28819..092591a73fbe00 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -9,26 +9,6 @@ from core.model_runtime.entities.model_entities import AIModelEntity -class AppMode(Enum): - COMPLETION = 'completion' # will be deprecated in the future - WORKFLOW = 'workflow' # instead of 'completion' - CHAT = 'chat' - AGENT = 'agent' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelConfigEntity(BaseModel): """ Model Config Entity. @@ -106,6 +86,38 @@ def value_of(cls, value: str) -> 'PromptType': advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntity(BaseModel): + """ + Variable Entity. + """ + class Type(Enum): + TEXT_INPUT = 'text-input' + SELECT = 'select' + PARAGRAPH = 'paragraph' + + @classmethod + def value_of(cls, value: str) -> 'VariableEntity.Type': + """ + Get value of given mode. 
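The user_input_form branches added above reduce to a small mapping from one-keyed form dicts onto VariableEntity. A condensed sketch of that conversion (it mirrors the elif chain rather than replacing it); note that the patch folds paragraph inputs into Type.TEXT_INPUT as well:

    from typing import Optional

    from core.entities.application_entities import VariableEntity

    def to_variable_entity(form_item: dict) -> Optional[VariableEntity]:
        # Each user_input_form item is a single-key dict: {type: config}.
        typ = list(form_item.keys())[0]
        val = form_item[typ]
        common = dict(
            variable=val.get('variable'),
            description=val.get('description'),
            label=val.get('label'),
            required=val.get('required', False),
            default=val.get('default'),
        )
        if typ in (VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value):
            # paragraph is treated as a text input here, matching the patch
            return VariableEntity(type=VariableEntity.Type.TEXT_INPUT,
                                  max_length=val.get('max_length'), **common)
        if typ == VariableEntity.Type.SELECT.value:
            return VariableEntity(type=VariableEntity.Type.SELECT,
                                  options=val.get('options'), **common)
        return None  # external_data_tool entries are handled separately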
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid variable type value {value}') + + variable: str + label: str + description: Optional[str] = None + type: Type + required: bool = False + max_length: Optional[int] = None + options: Optional[list[str]] = None + default: Optional[str] = None + + class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. @@ -245,6 +257,7 @@ class AppOrchestrationConfigEntity(BaseModel): """ model_config: ModelConfigEntity prompt_template: PromptTemplateEntity + variables: list[VariableEntity] = [] external_data_variables: list[ExternalDataVariableEntity] = [] agent: Optional[AgentEntity] = None @@ -256,7 +269,7 @@ class AppOrchestrationConfigEntity(BaseModel): show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False - text_to_speech: dict = {} + text_to_speech: Optional[TextToSpeechEntity] = None sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 4bf96ce2657fe9..abbfa962494b33 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -6,7 +6,6 @@ from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, - AppMode, ModelConfigEntity, PromptTemplateEntity, ) @@ -24,6 +23,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser +from models.model import AppMode class ModelMode(enum.Enum): diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/NodeEntities.py new file mode 100644 index 00000000000000..d72b000dfb876d --- /dev/null +++ b/api/core/workflow/entities/NodeEntities.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class NodeType(Enum): + """ + Node Types. + """ + START = 'start' + END = 'end' + DIRECT_ANSWER = 'direct-answer' + LLM = 'llm' + KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' + IF_ELSE = 'if-else' + CODE = 'code' + TEMPLATE_TRANSFORM = 'template-transform' + QUESTION_CLASSIFIER = 'question-classifier' + HTTP_REQUEST = 'http-request' + TOOL = 'tool' + VARIABLE_ASSIGNER = 'variable-assigner' + + @classmethod + def value_of(cls, value: str) -> 'BlockType': + """ + Get value of given block type. 
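NodeType gives the engine a closed vocabulary for canvas payloads, and value_of resolves a node's type string while rejecting unknown ones. A small usage sketch; the node dict is made up for illustration:

    from core.workflow.entities.NodeEntities import NodeType

    node = {'id': 'node-1', 'data': {'type': 'llm'}}  # illustrative canvas node

    try:
        node_type = NodeType.value_of(node['data']['type'])
    except ValueError:
        raise ValueError(f"unsupported node type in graph: {node['data']['type']}")

    if node_type is NodeType.LLM:
        print('dispatch to the LLM node implementation')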
+ + :param value: block type value + :return: block type + """ + for block_type in cls: + if block_type.value == value: + return block_type + raise ValueError(f'invalid block type value {value}') diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py new file mode 100644 index 00000000000000..045e7effc49711 --- /dev/null +++ b/api/core/workflow/nodes/end/entities.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class EndNodeOutputType(Enum): + """ + END Node Output Types. + + none, plain-text, structured + """ + NONE = 'none' + PLAIN_TEXT = 'plain-text' + STRUCTURED = 'structured' + + @classmethod + def value_of(cls, value: str) -> 'OutputType': + """ + Get value of given output type. + + :param value: output type value + :return: output type + """ + for output_type in cls: + if output_type.value == value: + return output_type + raise ValueError(f'invalid output type value {value}') diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 5974de34de201b..d9cd6c03bb5b81 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,12 +2,6 @@ from libs.helper import TimestampField -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - annotation_fields = { "id": fields.String, @@ -15,7 +9,7 @@ "answer": fields.Raw(attribute='content'), "hit_count": fields.Integer, "created_at": TimestampField, - # 'account': fields.Nested(account_fields, allow_null=True) + # 'account': fields.Nested(simple_account_fields, allow_null=True) } annotation_list_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1adc836aa2479f..afa486f1cdfe6c 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,5 +1,6 @@ from flask_restful import fields +from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -8,31 +9,25 @@ def format(self, value): return value[0]['text'] if value else '' -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - feedback_fields = { 'rating': fields.String, 'content': fields.String, 'from_source': fields.String, 'from_end_user_id': fields.String, - 'from_account': fields.Nested(account_fields, allow_null=True), + 'from_account': fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { 'id': fields.String, 'question': fields.String, 'content': fields.String, - 'account': fields.Nested(account_fields, allow_null=True), + 'account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } annotation_hit_history_fields = { 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': 
fields.Nested(account_fields, allow_null=True), + 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py new file mode 100644 index 00000000000000..79164b3848536d --- /dev/null +++ b/api/fields/member_fields.py @@ -0,0 +1,38 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +simple_account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'is_password_set': fields.Boolean, + 'interface_language': fields.String, + 'interface_theme': fields.String, + 'timezone': fields.String, + 'last_login_at': TimestampField, + 'last_login_ip': fields.String, + 'created_at': TimestampField +} + +account_with_role_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'last_login_at': TimestampField, + 'created_at': TimestampField, + 'role': fields.String, + 'status': fields.String, +} + +account_with_role_list_fields = { + 'accounts': fields.List(fields.Nested(account_with_role_fields)) +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py new file mode 100644 index 00000000000000..9dc92ea43bc03f --- /dev/null +++ b/api/fields/workflow_fields.py @@ -0,0 +1,16 @@ +import json + +from flask_restful import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + + +workflow_fields = { + 'id': fields.String, + 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), + 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), + 'created_at': TimestampField, + 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), + 'updated_at': TimestampField +} diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 605c66bed1139b..e9cd2caf3ae3c5 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -102,7 +102,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_pkey') ) with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False) + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) diff --git a/api/models/model.py b/api/models/model.py index 2b44957b06cf1d..58e29cd21c6634 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,5 +1,7 @@ import json import uuid +from enum import Enum +from typing import Optional from flask import current_app, request from flask_login import UserMixin @@ -25,6 +27,25 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) +class AppMode(Enum): + WORKFLOW = 'workflow' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. 
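workflow_fields above stores the graph as serialized text but marshals it back out as a JSON object, and resolves created_by/updated_by through the *_account properties added to the model further down. A sketch of what serialization looks like, assuming a Flask app context and a fully populated row:

    from flask_restful import marshal

    from fields.workflow_fields import workflow_fields
    from models.workflow import Workflow

    workflow = Workflow.query.first()  # assumes an app context and existing data
    payload = marshal(workflow, workflow_fields)

    # payload['graph'] is the decoded dict (json.loads of the stored text);
    # payload['created_by'] is the nested simple account representation.
    print(payload['graph'], payload['created_by'])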
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -56,7 +77,7 @@ def site(self): return site @property - def app_model_config(self): + def app_model_config(self) -> Optional['AppModelConfig']: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() return app_model_config @@ -130,6 +151,12 @@ def deleted_tools(self) -> list: return deleted_tools + +class ChatbotAppEngine(Enum): + NORMAL = 'normal' + WORKFLOW = 'workflow' + + class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( diff --git a/api/models/workflow.py b/api/models/workflow.py index 59b8eeb6cdb642..ed26e98896d286 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,43 @@ +from enum import Enum +from typing import Union + from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db +from models.account import Account +from models.model import AppMode + + +class WorkflowType(Enum): + """ + Workflow Type Enum + """ + WORKFLOW = 'workflow' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow type value {value}') + + @classmethod + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + """ + Get workflow type from app mode. + + :param app_mode: app mode + :return: workflow type + """ + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) + return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT class Workflow(db.Model): @@ -39,7 +76,7 @@ class Workflow(db.Model): __tablename__ = 'workflows' __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) @@ -53,6 +90,14 @@ class Workflow(db.Model): updated_by = db.Column(UUID) updated_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowRun(db.Model): """ @@ -116,6 +161,14 @@ class WorkflowRun(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowNodeExecution(db.Model): """ diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 3cf58d8e09be2f..1e893e0eca4cad 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,6 @@ import copy -from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -14,6 +13,7 @@ COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) +from 
models.model import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index ccfb101405cdf5..3ac11c645c9940 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -9,6 +9,7 @@ from core.moderation.factory import ModerationFactory from core.provider_manager import ProviderManager from models.account import Account +from models.model import AppMode from services.dataset_service import DatasetService SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -315,9 +316,6 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, if "tool_parameters" not in tool: raise ValueError("tool_parameters is required in agent_mode.tools") - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - # advanced prompt validation cls.is_advanced_prompt_valid(config, app_mode) @@ -443,21 +441,6 @@ def is_external_data_tools_valid(cls, tenant_id: str, config: dict): config=config ) - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") - @classmethod def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef416b63..5599c60113c3b5 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -8,12 +8,10 @@ from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser from services.app_model_config_service import AppModelConfigService -from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import MessageNotExistsError class CompletionService: @@ -157,62 +155,6 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, } ) - @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: - if not user: - raise ValueError('user cannot be None') - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() - - if not message: - raise MessageNotExistsError() - - current_app_model_config = app_model.app_model_config - more_like_this = current_app_model_config.more_like_this_dict - - if not 
current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: - raise MoreLikeThisDisabledError() - - app_model_config = message.app_model_config - model_dict = app_model_config.model_dict - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - app_model_config.model = json.dumps(model_dict) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) - - application_manager = ApplicationManager() - return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, - user=user, - invoke_from=invoke_from, - inputs=message.inputs, - query=message.query, - files=file_objs, - conversation=None, - stream=streaming, - extras={ - "auto_generate_conversation_name": False - } - ) - @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 5804f599fe63bf..a44c190cbc1d28 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'app', 'completion', 'audio', 'file' + 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/app.py b/api/services/errors/app.py deleted file mode 100644 index 7c4ca99c2ae869..00000000000000 --- a/api/services/errors/app.py +++ /dev/null @@ -1,2 +0,0 @@ -class MoreLikeThisDisabledError(Exception): - pass diff --git a/api/services/workflow/__init__.py b/api/services/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/workflow/defaults.py b/api/services/workflow/defaults.py new file mode 100644 index 00000000000000..67804fa4ebfc32 --- /dev/null +++ b/api/services/workflow/defaults.py @@ -0,0 +1,72 @@ +# default block config +default_block_configs = [ + { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + { + "role": "system", + "text": "You are a helpful AI assistant." 
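+                            # default system prompt for LLM blocks driven by chat-completion models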
+                        }
+                    ]
+                },
+                "completion_model": {
+                    "conversation_histories_role": {
+                        "user_prefix": "Human",
+                        "assistant_prefix": "Assistant"
+                    },
+                    "prompt": {
+                        "text": "Here is the chat histories between human and assistant, inside "
+                                "<histories></histories> XML tags.\n\n<histories>\n{{"
+                                "#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
+                    },
+                    "stop": ["Human:"]
+                }
+            }
+        }
+    },
+    {
+        "type": "code",
+        "config": {
+            "variables": [
+                {
+                    "variable": "arg1",
+                    "value_selector": []
+                },
+                {
+                    "variable": "arg2",
+                    "value_selector": []
+                }
+            ],
+            "code_language": "python3",
+            "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
+                    "+ arg2\n    }",
+            "outputs": [
+                {
+                    "variable": "result",
+                    "variable_type": "number"
+                }
+            ]
+        }
+    },
+    {
+        "type": "template-transform",
+        "config": {
+            "variables": [
+                {
+                    "variable": "arg1",
+                    "value_selector": []
+                }
+            ],
+            "template": "{{ arg1 }}"
+        }
+    },
+    {
+        "type": "question-classifier",
+        "config": {
+            "instructions": ""  # TODO
+        }
+    }
+]
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
new file mode 100644
index 00000000000000..c2fad83aaff4ce
--- /dev/null
+++ b/api/services/workflow/workflow_converter.py
@@ -0,0 +1,259 @@
+import json
+from typing import Optional
+
+from core.application_manager import ApplicationManager
+from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \
+    ExternalDataVariableEntity, DatasetEntity, VariableEntity
+from core.model_runtime.utils import helper
+from core.workflow.entities.NodeEntities import NodeType
+from core.workflow.nodes.end.entities import EndNodeOutputType
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, AppMode, ChatbotAppEngine
+from models.workflow import Workflow, WorkflowType
+
+
+class WorkflowConverter:
+    """
+    App Convert to Workflow Mode
+    """
+
+    def convert_to_workflow(self, app_model: App, account: Account) -> Workflow:
+        """
+        Convert to workflow mode
+
+        - basic mode of chatbot app
+
+        - advanced mode of assistant app (for migration)
+
+        - completion app (for migration)
+
+        :param app_model: App instance
+        :param account: Account instance
+        :return: workflow instance
+        """
+        # get original app config
+        app_model_config = app_model.app_model_config
+
+        # convert app model config
+        application_manager = ApplicationManager()
+        application_manager.convert_from_app_model_config_dict(
+            tenant_id=app_model.tenant_id,
+            app_model_config_dict=app_model_config.to_dict()
+        )
+
+        # init workflow graph
+        graph = {
+            "nodes": [],
+            "edges": []
+        }
+
+        # Convert list:
+        # - variables -> start
+        # - model_config -> llm
+        # - prompt_template -> llm
+        # - file_upload -> llm
+        # - external_data_variables -> http-request
+        # - dataset -> knowledge-retrieval
+        # - show_retrieve_source -> knowledge-retrieval
+
+        # convert to start node
+        start_node = self._convert_to_start_node(
+            variables=app_model_config.variables
+        )
+
+        graph['nodes'].append(start_node)
+
+        # convert to http request node
+        if app_model_config.external_data_variables:
+            http_request_node = self._convert_to_http_request_node(
+                external_data_variables=app_model_config.external_data_variables
+            )
+
+            graph = self._append_node(graph, http_request_node)
+
+        # convert to knowledge retrieval node
+        if app_model_config.dataset:
+            knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
+                new_app_mode=self._get_new_app_mode(app_model),
+                dataset=app_model_config.dataset,
+                show_retrieve_source=app_model_config.show_retrieve_source
+            )
+
+            graph = self._append_node(graph, knowledge_retrieval_node)
+
+        # convert to llm node
+        llm_node = self._convert_to_llm_node(
+            model_config=app_model_config.model_config,
+            prompt_template=app_model_config.prompt_template,
+            file_upload=app_model_config.file_upload
+        )
+
+        graph = self._append_node(graph, llm_node)
+
+        # convert to end node by app mode
+        end_node = self._convert_to_end_node(app_model=app_model)
+
+        graph = self._append_node(graph, end_node)
+
+        # get new app mode
+        app_mode = self._get_new_app_mode(app_model)
+
+        # create workflow record
+        workflow = Workflow(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type=WorkflowType.from_app_mode(app_mode).value,
+            version='draft',
+            graph=json.dumps(graph),
+            created_by=account.id
+        )
+
+        db.session.add(workflow)
+        db.session.flush()
+
+        # create new app model config record
+        new_app_model_config = app_model_config.copy()
+        new_app_model_config.external_data_tools = ''
+        new_app_model_config.model = ''
+        new_app_model_config.user_input_form = ''
+        new_app_model_config.dataset_query_variable = None
+        new_app_model_config.pre_prompt = None
+        new_app_model_config.agent_mode = ''
+        new_app_model_config.prompt_type = 'simple'
+        new_app_model_config.chat_prompt_config = ''
+        new_app_model_config.completion_prompt_config = ''
+        new_app_model_config.dataset_configs = ''
+        new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \
+            if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value
+        new_app_model_config.workflow_id = workflow.id
+
+        db.session.add(new_app_model_config)
+        db.session.commit()
+
+        return workflow
+
+    def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
+        """
+        Convert to Start Node
+        :param variables: list of variables
+        :return:
+        """
+        return {
+            "id": "start",
+            "position": None,
+            "data": {
+                "title": "START",
+                "type": NodeType.START.value,
+                "variables": [helper.dump_model(v) for v in variables]
+            }
+        }
+
+    def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict:
+        """
+        Convert API Based Extension to HTTP Request Node
+        :param external_data_variables: list of external data variables
+        :return:
+        """
+        # TODO: implement
+        pass
+
+    def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
+                                             dataset: DatasetEntity,
+                                             show_retrieve_source: bool) -> dict:
+        """
+        Convert datasets to Knowledge Retrieval Node
+        :param new_app_mode: new app mode
+        :param dataset: dataset
+        :param show_retrieve_source: whether to enable citation for retrieved chunks
+        :return:
+        """
+        # TODO: implement (show_retrieve_source will be applied once the node config is filled in)
+        if new_app_mode == AppMode.CHAT:
+            query_variable_selector = ["start", "sys.query"]
+        else:
+            pass
+
+        return {
+            "id": "knowledge-retrieval",
+            "position": None,
+            "data": {
+                "title": "KNOWLEDGE RETRIEVAL",
+                "type": NodeType.KNOWLEDGE_RETRIEVAL.value,
+            }
+        }
+
+    def _convert_to_llm_node(self, model_config: ModelConfigEntity,
+                             prompt_template: PromptTemplateEntity,
+                             file_upload: Optional[FileUploadEntity] = None) -> dict:
+        """
+        Convert to LLM Node
+        :param model_config: model config
+        :param prompt_template: prompt template
+        :param file_upload: file upload config (optional)
+        """
+        # TODO: implement
+        pass
+
+    def _convert_to_end_node(self, app_model: App) -> dict:
+        """
+        Convert to End Node
+        :param app_model: App instance
+        :return:
+        """
+        if app_model.mode == AppMode.CHAT.value:
+            return {
+                "id": "end",
+                "position": None,
+                "data": {
+                    "title": "END",
+                    "type": NodeType.END.value,
+                }
+            }
+        elif app_model.mode == "completion":
+            # for original completion app
+            return {
+                "id": "end",
+                "position": None,
+                "data": {
+                    "title": "END",
+                    "type": NodeType.END.value,
+                    "outputs": {
+                        "type": EndNodeOutputType.PLAIN_TEXT.value,
+                        "plain_text_selector": ["llm", "text"]
+                    }
+                }
+            }
+
+    def _create_edge(self, source: str, target: str) -> dict:
+        """
+        Create Edge
+        :param source: source node id
+        :param target: target node id
+        :return:
+        """
+        return {
+            "id": f"{source}-{target}",
+            "source": source,
+            "target": target
+        }
+
+    def _append_node(self, graph: dict, node: dict) -> dict:
+        """
+        Append Node to Graph
+
+        :param graph: Graph, include: nodes, edges
+        :param node: Node to append
+        :return:
+        """
+        previous_node = graph['nodes'][-1]
+        graph['nodes'].append(node)
+        graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
+        return graph
+
+    def _get_new_app_mode(self, app_model: App) -> AppMode:
+        """
+        Get new app mode
+        :param app_model: App instance
+        :return: AppMode
+        """
+        if app_model.mode == "completion":
+            return AppMode.WORKFLOW
+        else:
+            return AppMode.value_of(app_model.mode)
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
new file mode 100644
index 00000000000000..6a967e86ffd24f
--- /dev/null
+++ b/api/services/workflow_service.py
@@ -0,0 +1,83 @@
+import json
+from datetime import datetime
+
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, ChatbotAppEngine
+from models.workflow import Workflow, WorkflowType
+from services.workflow.defaults import default_block_configs
+from services.workflow.workflow_converter import WorkflowConverter
+
+
+class WorkflowService:
+    """
+    Workflow Service
+    """
+
+    def get_draft_workflow(self, app_model: App) -> Workflow:
+        """
+        Get draft workflow
+        """
+        # fetch draft workflow by app_model
+        workflow = db.session.query(Workflow).filter(
+            Workflow.tenant_id == app_model.tenant_id,
+            Workflow.app_id == app_model.id,
+            Workflow.version == 'draft'
+        ).first()
+
+        # return draft workflow
+        return workflow
+
+    def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow:
+        """
+        Sync draft workflow
+        """
+        # fetch draft workflow by app_model
+        workflow = self.get_draft_workflow(app_model=app_model)
+
+        # create draft workflow if not found
+        if not workflow:
+            workflow = Workflow(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                type=WorkflowType.from_app_mode(app_model.mode).value,
+                version='draft',
+                graph=json.dumps(graph),
+                created_by=account.id
+            )
+            db.session.add(workflow)
+        # update draft workflow if found
+        else:
+            workflow.graph = json.dumps(graph)
+            workflow.updated_by = account.id
+            workflow.updated_at = datetime.utcnow()
+
+        # commit db session changes
+        db.session.commit()
+
+        # return draft workflow
+        return workflow
+
+    def get_default_block_configs(self) -> list:
+        """
+        Get default block configs
+        """
+        # return default block config
+        return default_block_configs
+
+    def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workflow:
+        """
+        basic mode of chatbot app to workflow
+
+        :param app_model: App instance
+        :param account: Account instance
+        :return:
+        """
+        # check if chatbot app is in basic mode
+        if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL.value:
+            raise ValueError('Chatbot app already in workflow mode')
+
+        # convert to workflow mode
+        workflow_converter = WorkflowConverter()
+        workflow = workflow_converter.convert_to_workflow(app_model=app_model, account=account)
+
+        return workflow
From c028e5f889b835f5bf8ed84e4f2ccad7879b3d0c Mon Sep 17 
00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 03:20:28 +0800 Subject: [PATCH 006/160] add app convert codes --- api/controllers/console/app/conversation.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/workflow.py | 6 +- api/controllers/console/app/wraps.py | 2 +- api/core/app_runner/app_runner.py | 17 +- api/core/app_runner/basic_app_runner.py | 2 +- api/core/application_manager.py | 6 +- api/core/entities/application_entities.py | 1 - api/core/prompt/advanced_prompt_transform.py | 198 +++++++ .../generate_prompts/baichuan_chat.json | 6 +- .../generate_prompts/baichuan_completion.json | 4 +- .../prompt/generate_prompts/common_chat.json | 6 +- .../generate_prompts/common_completion.json | 4 +- api/core/prompt/prompt_builder.py | 10 - api/core/prompt/prompt_template.py | 3 +- api/core/prompt/prompt_transform.py | 548 +----------------- api/core/prompt/simple_prompt_transform.py | 298 ++++++++++ api/fields/annotation_fields.py | 1 - api/fields/workflow_fields.py | 1 - api/services/workflow/workflow_converter.py | 168 +++++- 20 files changed, 694 insertions(+), 591 deletions(-) create mode 100644 api/core/prompt/advanced_prompt_transform.py delete mode 100644 api/core/prompt/prompt_builder.py create mode 100644 api/core/prompt/simple_prompt_transform.py diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5d312149f714bb..daf96411211abd 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -21,7 +21,7 @@ ) from libs.helper import datetime_string from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, AppMode +from models.model import AppMode, Conversation, Message, MessageAnnotation class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 9a177116eacbc1..c384e878aaccf9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -26,7 +26,7 @@ from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 2794735bbb55a9..1bb0ea34c12340 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse, marshal_with +from flask_restful import Resource, marshal_with, reqparse from controllers.console import api from controllers.console.app.error import DraftWorkflowNotExist @@ -6,8 +6,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.workflow_fields import workflow_fields -from libs.login import login_required, current_user -from models.model import App, ChatbotAppEngine, AppMode +from libs.login import current_user, login_required +from models.model import App, AppMode, ChatbotAppEngine from services.workflow_service import WorkflowService diff --git 
a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index fe35e723043469..1c2c4cf5c7c970 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,7 +5,7 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, ChatbotAppEngine, AppMode +from models.model import App, AppMode, ChatbotAppEngine def get_app_model(view: Optional[Callable] = None, *, diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index f9678b372fce6b..c6f6268a7a73dc 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -22,7 +22,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_transform import PromptTransform +from core.prompt.simple_prompt_transform import SimplePromptTransform from models.model import App, Message, MessageAnnotation @@ -140,12 +140,11 @@ def organize_prompt_messages(self, app_record: App, :param memory: memory :return: """ - prompt_transform = PromptTransform() + prompt_transform = SimplePromptTransform() # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: prompt_messages, stop = prompt_transform.get_prompt( - app_mode=app_record.mode, prompt_template_entity=prompt_template_entity, inputs=inputs, query=query if query else '', @@ -155,17 +154,7 @@ def organize_prompt_messages(self, app_record: App, model_config=model_config ) else: - prompt_messages = prompt_transform.get_advanced_prompt( - app_mode=app_record.mode, - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - stop = model_config.stop + raise NotImplementedError("Advanced prompt is not supported yet.") return prompt_messages, stop diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index 26e9cc84aa770e..0e0fe6e3bfa099 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationException from extensions.ext_database import db -from models.model import App, Conversation, Message, AppMode +from models.model import App, AppMode, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 2fde422d4726b8..cf463be1df58f4 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -28,7 +28,8 @@ ModelConfigEntity, PromptTemplateEntity, SensitiveWordAvoidanceEntity, - TextToSpeechEntity, VariableEntity, + TextToSpeechEntity, + VariableEntity, ) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -541,8 +542,7 @@ def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_di query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] - ), - single_strategy=datasets.get('strategy', 'router') + ) ) ) else: diff --git 
a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 092591a73fbe00..f8f293d96a8146 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -156,7 +156,6 @@ def value_of(cls, value: str) -> 'RetrieveStrategy': query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy - single_strategy: Optional[str] = None # for temp top_k: Optional[int] = None score_threshold: Optional[float] = None reranking_model: Optional[dict] = None diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py new file mode 100644 index 00000000000000..9ca3ef0375df0b --- /dev/null +++ b/api/core/prompt/advanced_prompt_transform.py @@ -0,0 +1,198 @@ +from typing import Optional + +from core.entities.application_entities import PromptTemplateEntity, ModelConfigEntity, \ + AdvancedCompletionPromptTemplateEntity +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, UserPromptMessage, \ + SystemPromptMessage, AssistantPromptMessage, TextPromptMessageContent +from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompt_transform import PromptTransform +from core.prompt.simple_prompt_transform import ModelMode + + +class AdvancePromptTransform(PromptTransform): + """ + Advanced Prompt Transform for Workflow LLM Node. + """ + + def get_prompt(self, prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + prompt_messages = [] + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + prompt_messages = self._get_completion_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + elif model_mode == ModelMode.CHAT: + prompt_messages = self._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + return prompt_messages + + def _get_completion_model_prompt_messages(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + """ + Get completion model prompt messages. 
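+
+        Fills the context and conversation histories variables into the raw
+        completion template, then wraps the rendered prompt (plus any file
+        contents) in a single user prompt message.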
+ """ + raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt + + prompt_messages = [] + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix + self._set_histories_variable( + memory=memory, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) + + prompt = prompt_template.format( + prompt_inputs + ) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) + + return prompt_messages + + def _get_chat_model_prompt_messages(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + """ + Get chat model prompt messages. + """ + raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages + + prompt_messages = [] + + for prompt_item in raw_prompt_list: + raw_prompt = prompt_item.text + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = prompt_template.format( + prompt_inputs + ) + + if prompt_item.role == PromptMessageRole.USER: + prompt_messages.append(UserPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: + prompt_messages.append(SystemPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.ASSISTANT: + prompt_messages.append(AssistantPromptMessage(content=prompt)) + + if memory: + self._append_chat_histories(memory, prompt_messages, model_config) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + elif files: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + + return prompt_messages + + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + if '#context#' in prompt_template.variable_keys: + if context: + prompt_inputs['#context#'] = context + else: + prompt_inputs['#context#'] = '' + + def _set_query_variable(self, query: str, 
prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        if '#query#' in prompt_template.variable_keys:
+            if query:
+                prompt_inputs['#query#'] = query
+            else:
+                prompt_inputs['#query#'] = ''
+
+    def _set_histories_variable(self, memory: TokenBufferMemory,
+                                raw_prompt: str,
+                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
+                                prompt_template: PromptTemplateParser,
+                                prompt_inputs: dict,
+                                model_config: ModelConfigEntity) -> None:
+        if '#histories#' in prompt_template.variable_keys:
+            if memory:
+                inputs = {'#histories#': '', **prompt_inputs}
+                prompt_template = PromptTemplateParser(raw_prompt)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(
+                    content=prompt_template.format(prompt_inputs)
+                )
+
+                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+
+                histories = self._get_history_messages_from_memory(
+                    memory=memory,
+                    max_token_limit=rest_tokens,
+                    human_prefix=role_prefix.user,
+                    ai_prefix=role_prefix.assistant
+                )
+                prompt_inputs['#histories#'] = histories
+            else:
+                prompt_inputs['#histories#'] = ''
diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/generate_prompts/baichuan_chat.json
index 5bf83cd9c7634b..03b6a53cfff2d1 100644
--- a/api/core/prompt/generate_prompts/baichuan_chat.json
+++ b/api/core/prompt/generate_prompts/baichuan_chat.json
@@ -1,13 +1,13 @@
 {
     "human_prefix": "用户",
     "assistant_prefix": "助手",
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
-    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
+    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt",
         "histories_prompt"
     ],
-    "query_prompt": "\n\n用户:{{query}}",
+    "query_prompt": "\n\n用户:{{#query#}}",
     "stops": ["用户:"]
 }
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/generate_prompts/baichuan_completion.json
index a3a2054e830b90..ae8c0dac53392f 100644
--- a/api/core/prompt/generate_prompts/baichuan_completion.json
+++ b/api/core/prompt/generate_prompts/baichuan_completion.json
@@ -1,9 +1,9 @@
 {
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt"
     ],
-    "query_prompt": "{{query}}",
+    "query_prompt": "{{#query#}}",
     "stops": null
 }
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json
index 709a8d88669d2d..d398a512e670a7 100644
--- a/api/core/prompt/generate_prompts/common_chat.json
+++ b/api/core/prompt/generate_prompts/common_chat.json
@@ -1,13 +1,13 @@
 {
     "human_prefix": "Human",
     "assistant_prefix": "Assistant",
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
-    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
+    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt",
         "histories_prompt"
     ],
-    "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
+    "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
     "stops": ["\nHuman:", "</histories>"]
 }
diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json
index 9e7e8d68ef333b..c148772010fb05 100644
--- a/api/core/prompt/generate_prompts/common_completion.json
+++ b/api/core/prompt/generate_prompts/common_completion.json
@@ -1,9 +1,9 @@
 {
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt"
     ],
-    "query_prompt": "{{query}}",
+    "query_prompt": "{{#query#}}",
     "stops": null
 }
\ No newline at end of file
diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py
deleted file mode 100644
index 7727b0f92e83eb..00000000000000
--- a/api/core/prompt/prompt_builder.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from core.prompt.prompt_template import PromptTemplateParser
-
-
-class PromptBuilder:
-    @classmethod
-    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
-        prompt_template = PromptTemplateParser(prompt)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-        prompt = prompt_template.format(prompt_inputs)
-        return prompt
diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py
index 32c5a791de4209..454f92e3b7dff5 100644
--- a/api/core/prompt/prompt_template.py
+++ b/api/core/prompt/prompt_template.py
@@ -32,7 +32,8 @@ def replacer(match):
                 return PromptTemplateParser.remove_template_variables(value)
             return value
 
-        return re.sub(REGEX, replacer, self.template)
+        prompt = re.sub(REGEX, replacer, self.template)
+        return re.sub(r'<\|.*?\|>', '', prompt)
 
     @classmethod
     def remove_template_variables(cls, text: str):
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index abbfa962494b33..c0f70ae0bbea78 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,393 +1,13 @@
-import enum
-import json
-import os
-import re
 from typing import Optional, cast
 
-from 
core.entities.application_entities import ( - AdvancedCompletionPromptTemplateEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.file.file_obj import FileObj +from core.entities.application_entities import ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) +from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import PromptTemplateParser -from models.model import AppMode - - -class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'ModelMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') class PromptTransform: - def get_prompt(self, - app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name( - app_mode=app_mode, - provider=model_config.provider, - model=model_config.model - )) - - if app_mode == AppMode.CHAT and model_mode == ModelMode.CHAT: - stops = None - - prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - else: - stops = prompt_rules.get('stops') - if stops is not None and len(stops) == 0: - stops = None - - prompt_messages = self._get_simple_others_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - return prompt_messages, stops - - def get_advanced_prompt(self, app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_messages = [] - - if app_mode == AppMode.CHAT: - if model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_chat_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif model_mode == ModelMode.CHAT: - prompt_messages = self._get_chat_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif app_mode 
== AppMode.COMPLETION: - if model_mode == ModelMode.CHAT: - prompt_messages = self._get_completion_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - context=context, - ) - elif model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_completion_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - context=context, - ) - - return prompt_messages - - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: - """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } - - if human_prefix: - kwargs['human_prefix'] = human_prefix - - if ai_prefix: - kwargs['ai_prefix'] = ai_prefix - - return memory.get_history_prompt_text( - **kwargs - ) - - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int) -> list[PromptMessage]: - """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit - ) - - def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: - # baichuan - if provider == 'baichuan': - return self._prompt_file_name_for_baichuan(app_mode) - - baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): - return self._prompt_file_name_for_baichuan(app_mode) - - # common - if app_mode == AppMode.COMPLETION: - return 'common_completion' - else: - return 'common_chat' - - def _prompt_file_name_for_baichuan(self, app_mode: AppMode) -> str: - if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' - else: - return 'baichuan_chat' - - def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: - # Get the absolute path of the subdirectory - prompt_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'generate_prompts') - - json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') - # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: - return json.load(json_file) - - def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - prompt_messages = [] - - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - if prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - - self._append_chat_histories( - memory=memory, - prompt_messages=prompt_messages, - model_config=model_config - ) - - if files: - prompt_message_contents = 
[TextPromptMessageContent(data=query)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_simple_others_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - files: list[FileObj], - model_config: ModelConfigEntity) -> list[PromptMessage]: - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' - - if memory and 'histories_prompt' in prompt_rules: - # append chat histories - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=prompt + query_prompt, - inputs={ - 'query': query - } - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' - ) - prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) - histories_prompt_content = prompt_template.format({'histories': histories}) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' - elif order == 'histories_prompt': - prompt += histories_prompt_content - - prompt_template = PromptTemplateParser(template=query_prompt) - query_prompt_content = prompt_template.format({'query': query}) - - prompt += query_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - model_mode = ModelMode.value_of(model_config.mode) - - if model_mode == ModelMode.CHAT and files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - prompt_message = UserPromptMessage(content=prompt) - - return [prompt_message] - - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#context#' in prompt_template.variable_keys: - if context: - prompt_inputs['#context#'] = context - else: - 
prompt_inputs['#context#'] = '' - - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#query#' in prompt_template.variable_keys: - if query: - prompt_inputs['#query#'] = query - else: - prompt_inputs['#query#'] = '' - - def _set_histories_variable(self, memory: TokenBufferMemory, - raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigEntity) -> None: - if '#histories#' in prompt_template.variable_keys: - if memory: - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=raw_prompt, - inputs={'#histories#': '', **prompt_inputs} - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant - ) - prompt_inputs['#histories#'] = histories - else: - prompt_inputs['#histories#'] = '' - def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> None: @@ -422,152 +42,28 @@ def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_conf return rest_tokens - def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str: - prompt = prompt_template.format( - prompt_inputs - ) - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - return prompt - - def _get_chat_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - - prompt_messages = [] - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None) -> str: + """Get memory messages.""" + kwargs = { + "max_token_limit": max_token_limit + } - self._set_context_variable(context, prompt_template, prompt_inputs) + if human_prefix: + kwargs['human_prefix'] = human_prefix - self._set_query_variable(query, prompt_template, prompt_inputs) + if ai_prefix: + kwargs['ai_prefix'] = ai_prefix - self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config + return memory.get_history_prompt_text( + **kwargs ) - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_chat_app_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: 
Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - self._append_chat_histories(memory, prompt_messages, model_config) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_completion_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - context: Optional[str]) -> list[PromptMessage]: - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - - prompt_messages = [] - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_completion_app_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - files: list[FileObj], - context: Optional[str]) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - for prompt_message in prompt_messages[::-1]: - if prompt_message.role == PromptMessageRole.USER: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message.content = prompt_message_contents - break - - return prompt_messages + def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + max_token_limit: int) -> list[PromptMessage]: + """Get memory 
messages.""" + return memory.get_history_prompt_messages( + max_token_limit=max_token_limit + ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py new file mode 100644 index 00000000000000..a898c37c4a8062 --- /dev/null +++ b/api/core/prompt/simple_prompt_transform.py @@ -0,0 +1,298 @@ +import enum +import json +import os +from typing import Optional, Tuple + +from core.entities.application_entities import ( + ModelConfigEntity, + PromptTemplateEntity, +) +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompt_transform import PromptTransform +from models.model import AppMode + + +class ModelMode(enum.Enum): + COMPLETION = 'completion' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'ModelMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + +prompt_file_contents = {} + + +class SimplePromptTransform(PromptTransform): + """ + Simple Prompt Transform for Chatbot App Basic Mode. + """ + def get_prompt(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> \ + tuple[list[PromptMessage], Optional[list[str]]]: + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.CHAT: + prompt_messages, stops = self._get_chat_model_prompt_messages( + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + else: + prompt_messages, stops = self._get_completion_model_prompt_messages( + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + return prompt_messages, stops + + def get_prompt_str_and_rules(self, app_mode: AppMode, + model_config: ModelConfigEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> Tuple[str, dict]: + # get prompt template + prompt_template_config = self.get_prompt_template( + app_mode=app_mode, + provider=model_config.provider, + model=model_config.model, + pre_prompt=pre_prompt, + has_context=context is not None, + query_in_prompt=query is not None, + with_memory_prompt=histories is not None + ) + + variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + + for v in prompt_template_config['special_variable_keys']: + # support #context#, #query# and #histories# + if v == '#context#': + variables['#context#'] = context if context else '' + elif v == '#query#': + variables['#query#'] = query if query else '' + elif v == '#histories#': + variables['#histories#'] = histories if histories else '' + + prompt_template = prompt_template_config['prompt_template'] + prompt = prompt_template.format(variables) + + return prompt, prompt_template_config['prompt_rules'] + + def get_prompt_template(self, app_mode: AppMode, + provider: str, + model: str, + 
pre_prompt: str,
+                            has_context: bool,
+                            query_in_prompt: bool,
+                            with_memory_prompt: bool = False) -> dict:
+        prompt_rules = self._get_prompt_rule(
+            app_mode=app_mode,
+            provider=provider,
+            model=model
+        )
+
+        custom_variable_keys = []
+        special_variable_keys = []
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt' and has_context:
+                prompt += prompt_rules['context_prompt']
+                special_variable_keys.append('#context#')
+            elif order == 'pre_prompt' and pre_prompt:
+                prompt += pre_prompt + '\n'
+                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
+                custom_variable_keys = pre_prompt_template.variable_keys
+            elif order == 'histories_prompt' and with_memory_prompt:
+                prompt += prompt_rules['histories_prompt']
+                special_variable_keys.append('#histories#')
+
+        if query_in_prompt:
+            prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}'
+            special_variable_keys.append('#query#')
+
+        return {
+            "prompt_template": PromptTemplateParser(template=prompt),
+            "custom_variable_keys": custom_variable_keys,
+            "special_variable_keys": special_variable_keys,
+            "prompt_rules": prompt_rules
+        }
+
+    def _get_chat_model_prompt_messages(self, pre_prompt: str,
+                                        inputs: dict,
+                                        query: str,
+                                        context: Optional[str],
+                                        files: list[FileObj],
+                                        memory: Optional[TokenBufferMemory],
+                                        model_config: ModelConfigEntity) \
+            -> Tuple[list[PromptMessage], Optional[list[str]]]:
+        prompt_messages = []
+
+        # get prompt (the query must not be rendered into the system prompt here,
+        # it is appended separately as the final user message below)
+        prompt, _ = self.get_prompt_str_and_rules(
+            app_mode=AppMode.CHAT,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=None,
+            context=context
+        )
+
+        if prompt:
+            prompt_messages.append(SystemPromptMessage(content=prompt))
+
+        self._append_chat_histories(
+            memory=memory,
+            prompt_messages=prompt_messages,
+            model_config=model_config
+        )
+
+        prompt_messages.append(self.get_last_user_message(query, files))
+
+        return prompt_messages, None
+
+    def _get_completion_model_prompt_messages(self, pre_prompt: str,
+                                              inputs: dict,
+                                              query: str,
+                                              context: Optional[str],
+                                              files: list[FileObj],
+                                              memory: Optional[TokenBufferMemory],
+                                              model_config: ModelConfigEntity) \
+            -> Tuple[list[PromptMessage], Optional[list[str]]]:
+        # get prompt
+        prompt, prompt_rules = self.get_prompt_str_and_rules(
+            app_mode=AppMode.CHAT,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=query,
+            context=context
+        )
+
+        if memory:
+            tmp_human_message = UserPromptMessage(
+                content=prompt
+            )
+
+            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+            # map the assistant prefix to the AI side and the human prefix to the user side
+            histories = self._get_history_messages_from_memory(
+                memory=memory,
+                max_token_limit=rest_tokens,
+                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant',
+                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            )
+
+            # get prompt
+            prompt, prompt_rules = self.get_prompt_str_and_rules(
+                app_mode=AppMode.CHAT,
+                model_config=model_config,
+                pre_prompt=pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=context,
+                histories=histories
+            )
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return [self.get_last_user_message(prompt, files)], stops
+
+    def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage:
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            # attach each uploaded file (e.g. image) as extra message content
+            for file in files:
prompt_message_contents.append(file.prompt_message_content) + + prompt_message = UserPromptMessage(content=prompt_message_contents) + else: + prompt_message = UserPromptMessage(content=prompt) + + return prompt_message + + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict: + """ + Get simple prompt rule. + :param app_mode: app mode + :param provider: model provider + :param model: model name + :return: + """ + prompt_file_name = self._prompt_file_name( + app_mode=app_mode, + provider=provider, + model=model + ) + + # Check if the prompt file is already loaded + if prompt_file_name in prompt_file_contents: + return prompt_file_contents[prompt_file_name] + + # Get the absolute path of the subdirectory + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') + json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + + # Open the JSON file and read its content + with open(json_file_path, encoding='utf-8') as json_file: + content = json.load(json_file) + + # Store the content of the prompt file + prompt_file_contents[prompt_file_name] = content + + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: + # baichuan + is_baichuan = False + if provider == 'baichuan': + is_baichuan = True + else: + baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] + if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + is_baichuan = True + + if is_baichuan: + if app_mode == AppMode.WORKFLOW: + return 'baichuan_completion' + else: + return 'baichuan_chat' + + # common + if app_mode == AppMode.WORKFLOW: + return 'common_completion' + else: + return 'common_chat' diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index d9cd6c03bb5b81..c77808447519fc 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,7 +2,6 @@ from libs.helper import TimestampField - annotation_fields = { "id": fields.String, "question": fields.String, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9dc92ea43bc03f..decdc0567f1934 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,7 +5,6 @@ from fields.member_fields import simple_account_fields from libs.helper import TimestampField - workflow_fields = { 'id': fields.String, 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c2fad83aaff4ce..7d18f4f675535a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -2,9 +2,17 @@ from typing import Optional from core.application_manager import ApplicationManager -from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \ - ExternalDataVariableEntity, DatasetEntity, VariableEntity +from core.entities.application_entities import ( + DatasetEntity, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, DatasetRetrieveConfigEntity, +) +from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper +from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from extensions.ext_database import db 
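[Note: the converter diff below now imports SimplePromptTransform to derive prompt templates during app-to-workflow conversion. As a rough, hedged illustration of the API it calls — not part of the patch, with made-up values, and assuming the bundled generate_prompts/common_chat.json rule file orders context_prompt before the pre-prompt, as the unit tests later in this series also assume:

    # Illustrative sketch only -- not part of the patch.
    from core.prompt.simple_prompt_transform import SimplePromptTransform
    from models.model import AppMode

    transform = SimplePromptTransform()
    config = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider='openai',
        model='gpt-4',
        pre_prompt='You are {{name}}.',  # {{name}} becomes a custom variable key
        has_context=True,                # pulls in the rule file's context_prompt
        query_in_prompt=False,           # chat models send the query separately
    )

    assert config['custom_variable_keys'] == ['name']
    assert config['special_variable_keys'] == ['#context#']

    # format() substitutes custom and special variables alike
    system_prompt = config['prompt_template'].format({
        'name': 'a helpful assistant',
        '#context#': 'Retrieved passages go here.',
    })
]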
@@ -32,6 +40,9 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: :param account: Account instance :return: workflow instance """ + # get new app mode + new_app_mode = self._get_new_app_mode(app_model) + # get original app config app_model_config = app_model.app_model_config @@ -75,14 +86,17 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: # convert to knowledge retrieval node if app_model_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - dataset=app_model_config.dataset, - show_retrieve_source=app_model_config.show_retrieve_source + new_app_mode=new_app_mode, + dataset_config=app_model_config.dataset ) - graph = self._append_node(graph, knowledge_retrieval_node) + if knowledge_retrieval_node: + graph = self._append_node(graph, knowledge_retrieval_node) # convert to llm node llm_node = self._convert_to_llm_node( + new_app_mode=new_app_mode, + graph=graph, model_config=app_model_config.model_config, prompt_template=app_model_config.prompt_template, file_upload=app_model_config.file_upload @@ -95,14 +109,11 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: graph = self._append_node(graph, end_node) - # get new app mode - app_mode = self._get_new_app_mode(app_model) - # create workflow record workflow = Workflow( tenant_id=app_model.tenant_id, app_id=app_model.id, - type=WorkflowType.from_app_mode(app_mode).value, + type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), created_by=account.id @@ -124,7 +135,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: new_app_model_config.completion_prompt_config = '' new_app_model_config.dataset_configs = '' new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value + if new_app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value new_app_model_config.workflow_id = workflow.id db.session.add(new_app_model_config) @@ -157,18 +168,22 @@ def _convert_to_http_request_node(self, external_data_variables: list[ExternalDa # TODO: implement pass - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity) -> dict: + def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \ + -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode - :param dataset: dataset + :param dataset_config: dataset :return: """ - # TODO: implement + retrieve_config = dataset_config.retrieve_config if new_app_mode == AppMode.CHAT: query_variable_selector = ["start", "sys.query"] + elif retrieve_config.query_variable: + # fetch query variable + query_variable_selector = ["start", retrieve_config.query_variable] else: - pass + return None return { "id": "knowledge-retrieval", @@ -176,20 +191,139 @@ def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: D "data": { "title": "KNOWLEDGE RETRIEVAL", "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "query_variable_selector": query_variable_selector, + "dataset_ids": dataset_config.dataset_ids, + "retrieval_mode": retrieve_config.retrieve_strategy.value, + "multiple_retrieval_config": { + "top_k": retrieve_config.top_k, + "score_threshold": retrieve_config.score_threshold, + "reranking_model": retrieve_config.reranking_model + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + 
else None, } } - def _convert_to_llm_node(self, model_config: ModelConfigEntity, + def _convert_to_llm_node(self, new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, file_upload: Optional[FileUploadEntity] = None) -> dict: """ Convert to LLM Node + :param new_app_mode: new app mode + :param graph: graph :param model_config: model config :param prompt_template: prompt template :param file_upload: file upload config (optional) """ - # TODO: implement - pass + # fetch start and knowledge retrieval node + start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) + knowledge_retrieval_node = next(filter( + lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, + graph['nodes'] + ), None) + + role_prefix = None + + # Chat Model + if model_config.mode == LLMMode.CHAT.value: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = [ + { + "role": 'user', + "text": prompt_template_config['prompt_template'].template + } + ] + else: + advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template + prompts = [helper.dump_model(m) for m in advanced_chat_prompt_template.messages] \ + if advanced_chat_prompt_template else [] + # Completion Model + else: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = { + "text": prompt_template_config['prompt_template'].template + } + + prompt_rules = prompt_template_config['prompt_rules'] + role_prefix = { + "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + } + else: + advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template + prompts = { + "text": advanced_completion_prompt_template.prompt, + } if advanced_completion_prompt_template else {"text": ""} + + if advanced_completion_prompt_template.role_prefix: + role_prefix = { + "user": advanced_completion_prompt_template.role_prefix.user, + "assistant": advanced_completion_prompt_template.role_prefix.assistant + } + + memory = None + if new_app_mode == AppMode.CHAT: + memory = { + "role_prefix": role_prefix, + "window": { + "enabled": False + } + } + + return { + "id": "llm", + "position": None, + "data": { + "title": "LLM", + "type": NodeType.LLM.value, + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": model_config.parameters.update({"stop": model_config.stop}) + }, + "variables": [{ + "variable": v['variable'], + "value_selector": ["start", v['variable']] + } for v in start_node['data']['variables']], + "prompts": prompts, + "memory": memory, + "context": { + "enabled": 
knowledge_retrieval_node is not None, + "variable_selector": ["knowledge-retrieval", "result"] + if knowledge_retrieval_node is not None else None + }, + "vision": { + "enabled": file_upload is not None, + "variable_selector": ["start", "sys.files"] if file_upload is not None else None, + "configs": { + "detail": file_upload.image_config['detail'] + } if file_upload is not None else None + } + } + } def _convert_to_end_node(self, app_model: App) -> dict: """ From 8642354a2aaf7ad6b758048c9d50ef5ee5efb195 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 03:20:39 +0800 Subject: [PATCH 007/160] lint --- api/core/prompt/advanced_prompt_transform.py | 17 +++++++++++++---- api/core/prompt/simple_prompt_transform.py | 8 ++++---- api/services/workflow/workflow_converter.py | 3 ++- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 9ca3ef0375df0b..397f708f1f8ea2 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,11 +1,20 @@ from typing import Optional -from core.entities.application_entities import PromptTemplateEntity, ModelConfigEntity, \ - AdvancedCompletionPromptTemplateEntity +from core.entities.application_entities import ( + AdvancedCompletionPromptTemplateEntity, + ModelConfigEntity, + PromptTemplateEntity, +) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, UserPromptMessage, \ - SystemPromptMessage, AssistantPromptMessage, TextPromptMessageContent +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a898c37c4a8062..6e158bef3932f5 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,7 @@ import enum import json import os -from typing import Optional, Tuple +from typing import Optional from core.entities.application_entities import ( ModelConfigEntity, @@ -85,7 +85,7 @@ def get_prompt_str_and_rules(self, app_mode: AppMode, query: Optional[str] = None, context: Optional[str] = None, histories: Optional[str] = None, - ) -> Tuple[str, dict]: + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, files: list[FileObj], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) \ - -> Tuple[list[PromptMessage], Optional[list[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -193,7 +193,7 @@ def _get_completion_model_prompt_messages(self, pre_prompt: str, files: list[FileObj], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) \ - -> Tuple[list[PromptMessage], Optional[list[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=AppMode.CHAT, diff --git 
a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 7d18f4f675535a..647713b404cc90 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -4,11 +4,12 @@ from core.application_manager import ApplicationManager from core.entities.application_entities import ( DatasetEntity, + DatasetRetrieveConfigEntity, ExternalDataVariableEntity, FileUploadEntity, ModelConfigEntity, PromptTemplateEntity, - VariableEntity, DatasetRetrieveConfigEntity, + VariableEntity, ) from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper From 3b234febf5a04565b92590ec077b079fd20a4578 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 15:15:42 +0800 Subject: [PATCH 008/160] fix bugs and add unit tests --- api/core/prompt/simple_prompt_transform.py | 35 +-- api/models/workflow.py | 4 +- api/tests/unit_tests/.gitignore | 1 + api/tests/unit_tests/__init__.py | 0 api/tests/unit_tests/conftest.py | 7 + api/tests/unit_tests/core/__init__.py | 0 api/tests/unit_tests/core/prompt/__init__.py | 0 .../core/prompt/test_prompt_transform.py | 47 ++++ .../prompt/test_simple_prompt_transform.py | 216 ++++++++++++++++++ 9 files changed, 292 insertions(+), 18 deletions(-) create mode 100644 api/tests/unit_tests/.gitignore create mode 100644 api/tests/unit_tests/__init__.py create mode 100644 api/tests/unit_tests/conftest.py create mode 100644 api/tests/unit_tests/core/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/test_prompt_transform.py create mode 100644 api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 6e158bef3932f5..a51cc86e8b3268 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. 
""" + def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, @@ -154,12 +155,12 @@ def get_prompt_template(self, app_mode: AppMode, } def _get_chat_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] @@ -169,7 +170,7 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, - query=query, + query=None, context=context ) @@ -187,12 +188,12 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, return prompt_messages, None def _get_completion_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( @@ -259,7 +260,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict provider=provider, model=model ) - + # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] @@ -267,14 +268,16 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') - + # Open the JSON file and read its content with open(json_file_path, encoding='utf-8') as json_file: content = json.load(json_file) - + # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content + return content + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False diff --git a/api/models/workflow.py b/api/models/workflow.py index ed26e98896d286..95805e787129d5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,7 +5,6 @@ from extensions.ext_database import db from models.account import Account -from models.model import AppMode class WorkflowType(Enum): @@ -29,13 +28,14 @@ def value_of(cls, value: str) -> 'WorkflowType': raise ValueError(f'invalid workflow type value {value}') @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': """ Get workflow type from app mode. 
:param app_mode: app mode :return: workflow type """ + from models.model import AppMode app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT diff --git a/api/tests/unit_tests/.gitignore b/api/tests/unit_tests/.gitignore new file mode 100644 index 00000000000000..426667562b31da --- /dev/null +++ b/api/tests/unit_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py new file mode 100644 index 00000000000000..afc9802cf1cbe7 --- /dev/null +++ b/api/tests/unit_tests/conftest.py @@ -0,0 +1,7 @@ +import os + +# Getting the absolute path of the current file's directory +ABS_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) diff --git a/api/tests/unit_tests/core/__init__.py b/api/tests/unit_tests/core/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/prompt/__init__.py b/api/tests/unit_tests/core/prompt/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py new file mode 100644 index 00000000000000..8a260b05072c1b --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +from core.entities.application_entities import ModelConfigEntity +from core.entities.provider_configuration import ProviderModelBundle +from core.model_runtime.entities.message_entities import UserPromptMessage +from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.prompt_transform import PromptTransform + + +def test__calculate_rest_token(): + model_schema_mock = MagicMock(spec=AIModelEntity) + parameter_rule_mock = MagicMock(spec=ParameterRule) + parameter_rule_mock.name = 'max_tokens' + model_schema_mock.parameter_rules = [ + parameter_rule_mock + ] + model_schema_mock.model_properties = { + ModelPropertyKey.CONTEXT_SIZE: 62 + } + + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + large_language_model_mock.get_num_tokens.return_value = 6 + + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.model = 'gpt-4' + model_config_mock.credentials = {} + model_config_mock.parameters = { + 'max_tokens': 50 + } + model_config_mock.model_schema = model_schema_mock + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + prompt_transform = PromptTransform() + + prompt_messages = [UserPromptMessage(content="Hello, how are you?")] + rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + + # Validate based on the mock configuration and expected logic + expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters['max_tokens'] + - 
large_language_model_mock.get_num_tokens.return_value) + assert rest_tokens == expected_rest_tokens + assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py new file mode 100644 index 00000000000000..cb6ad02541d2c5 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -0,0 +1,216 @@ +from unittest.mock import MagicMock + +from core.entities.application_entities import ModelConfigEntity +from core.prompt.simple_prompt_transform import SimplePromptTransform +from models.model import AppMode + + +def test_get_common_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_baichuan_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_common_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_baichuan_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
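+    # provider 'baichuan' with a workflow app resolves to the
+    # baichuan_completion rule file (see _prompt_file_name above)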
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + print(prompt_template['prompt_template'].template) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_q(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] + assert prompt_template['special_variable_keys'] == ['#query#'] + + +def test_get_common_chat_app_prompt_template_with_cq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_p(): + prompt_transform = SimplePromptTransform() + pre_prompt = "you are {{name}}" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=False, + with_memory_prompt=False, + ) + assert prompt_template['prompt_template'].template == pre_prompt + '\n' + assert prompt_template['custom_variable_keys'] == ['name'] + assert prompt_template['special_variable_keys'] == [] + + +def test__get_chat_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" 
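+    # the pre-prompt and context fold into a single system message; the
+    # query stays a separate user message (asserted below)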
+ prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=False, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context} + real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 2 + assert prompt_messages[0].content == real_system_prompt + assert prompt_messages[1].content == query + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" + prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context, '#query#': query} + real_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 1 + assert stops == prompt_template['prompt_rules'].get('stops') + assert prompt_messages[0].content == real_prompt From 6aecf42b6e5d05659ba589f62dc1d6645ba85de9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 22:32:33 +0800 Subject: [PATCH 009/160] fix prompt transform bugs --- api/core/prompt/advanced_prompt_transform.py | 26 ++- api/core/prompt/prompt_transform.py | 4 +- api/core/prompt/simple_prompt_transform.py | 2 +- .../prompt/test_advanced_prompt_transform.py | 193 ++++++++++++++++++ .../prompt/test_simple_prompt_transform.py | 46 ++++- 5 files changed, 251 insertions(+), 20 deletions(-) create mode 100644 api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 397f708f1f8ea2..0ed9ec352cef96 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -20,7 +20,7 @@ from core.prompt.simple_prompt_transform import ModelMode -class AdvancePromptTransform(PromptTransform): +class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. 
""" @@ -74,10 +74,10 @@ def _get_completion_model_prompt_messages(self, prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - self._set_histories_variable( + prompt_inputs = self._set_histories_variable( memory=memory, raw_prompt=raw_prompt, role_prefix=role_prefix, @@ -104,7 +104,7 @@ def _get_completion_model_prompt_messages(self, def _get_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - query: str, + query: Optional[str], files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], @@ -122,7 +122,7 @@ def _get_chat_model_prompt_messages(self, prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) prompt = prompt_template.format( prompt_inputs @@ -136,7 +136,7 @@ def _get_chat_model_prompt_messages(self, prompt_messages.append(AssistantPromptMessage(content=prompt)) if memory: - self._append_chat_histories(memory, prompt_messages, model_config) + prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config) if files: prompt_message_contents = [TextPromptMessageContent(data=query)] @@ -157,7 +157,7 @@ def _get_chat_model_prompt_messages(self, last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -165,26 +165,30 @@ def _get_chat_model_prompt_messages(self, return prompt_messages - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: if '#context#' in prompt_template.variable_keys: if context: prompt_inputs['#context#'] = context else: prompt_inputs['#context#'] = '' - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + return prompt_inputs + + def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: if '#query#' in prompt_template.variable_keys: if query: prompt_inputs['#query#'] = query else: prompt_inputs['#query#'] = '' + return prompt_inputs + def _set_histories_variable(self, memory: TokenBufferMemory, raw_prompt: str, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, prompt_template: PromptTemplateParser, prompt_inputs: dict, - model_config: ModelConfigEntity) -> None: + model_config: ModelConfigEntity) -> dict: if '#histories#' in prompt_template.variable_keys: if memory: inputs = {'#histories#': '', **prompt_inputs} @@ -205,3 +209,5 @@ def _set_histories_variable(self, memory: TokenBufferMemory, prompt_inputs['#histories#'] = histories else: prompt_inputs['#histories#'] = '' + + return prompt_inputs diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 
c0f70ae0bbea78..9596976b6eff77 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -10,12 +10,14 @@ class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], - model_config: ModelConfigEntity) -> None: + model_config: ModelConfigEntity) -> list[PromptMessage]: if memory: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, rest_tokens) prompt_messages.extend(histories) + return prompt_messages + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int: rest_tokens = 2000 diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a51cc86e8b3268..2f98fbcae8a1c1 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -177,7 +177,7 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, if prompt: prompt_messages.append(SystemPromptMessage(content=prompt)) - self._append_chat_histories( + prompt_messages = self._append_chat_histories( memory=memory, prompt_messages=prompt_messages, model_config=model_config diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py new file mode 100644 index 00000000000000..65a160a8e592de --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -0,0 +1,193 @@ +from unittest.mock import MagicMock + +import pytest + +from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ + ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity +from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.prompt_template import PromptTemplateParser +from models.model import Conversation + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt=prompt_template, + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + inputs = { + "name": "John" + } + files = [] + context = "I am superman." 
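+    # stub a two-turn history so the {{#histories#}} placeholder in the
+    # completion template is filled from memory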
+ + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_completion_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 1 + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ + "#context#": context, + "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " + f"{prompt.content}" for prompt in history_prompt_messages]), + **inputs, + }) + + +def test__get_chat_model_prompt_messages(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [] + query = "Hi2." + + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi1."), + AssistantPromptMessage(content="Hello1!") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 6 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + assert prompt_messages[5].content == query + + +def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=None, + files=files, + context=context, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 3 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + + +def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [ + FileObj( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + url="https://example.com/image1.jpg", + file_config={ + "image": { + "detail": "high", + } + } + ) + ] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = 
prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=None, + files=files, + context=context, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 4 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + assert isinstance(prompt_messages[3].content, list) + assert len(prompt_messages[3].content) == 2 + assert prompt_messages[3].content[1].data == files[0].url + + +@pytest.fixture +def get_chat_model_args(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ) + ) + + inputs = { + "name": "John" + } + + context = "I am superman." + + return model_config_mock, prompt_template_entity, inputs, context diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index cb6ad02541d2c5..c174983e384e99 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,8 +1,10 @@ from unittest.mock import MagicMock from core.entities.application_entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import AppMode +from models.model import AppMode, Conversation def test_get_common_chat_app_prompt_template_with_pcqm(): @@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages(): model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' + memory_mock = MagicMock(spec=TokenBufferMemory) + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory_mock.get_history_prompt_messages.return_value = history_prompt_messages + prompt_transform = SimplePromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + pre_prompt = "You are a helpful assistant {{name}}." 
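+    # the mocked memory above injects a Hi/Hello exchange, so two history
+    # messages should appear between the system prompt and the final query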
inputs = { "name": "John" @@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages(): query=query, files=[], context=context, - memory=None, + memory=memory_mock, model_config=model_config_mock ) @@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages(): full_inputs = {**inputs, '#context#': context} real_system_prompt = prompt_template['prompt_template'].format(full_inputs) - assert len(prompt_messages) == 2 + assert len(prompt_messages) == 4 assert prompt_messages[0].content == real_system_prompt - assert prompt_messages[1].content == query + assert prompt_messages[1].content == history_prompt_messages[0].content + assert prompt_messages[2].content == history_prompt_messages[1].content + assert prompt_messages[3].content == query def test__get_completion_model_prompt_messages(): @@ -181,7 +194,19 @@ def test__get_completion_model_prompt_messages(): model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-3.5-turbo-instruct' + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + prompt_transform = SimplePromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." inputs = { "name": "John" @@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages(): query=query, files=[], context=context, - memory=None, + memory=memory, model_config=model_config_mock ) @@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages(): pre_prompt=pre_prompt, has_context=True, query_in_prompt=True, - with_memory_prompt=False, + with_memory_prompt=True, ) - full_inputs = {**inputs, '#context#': context, '#query#': query} + prompt_rules = prompt_template['prompt_rules'] + full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( + max_token_limit=2000, + ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + )} real_prompt = prompt_template['prompt_template'].format(full_inputs) assert len(prompt_messages) == 1 - assert stops == prompt_template['prompt_rules'].get('stops') + assert stops == prompt_rules.get('stops') assert prompt_messages[0].content == real_prompt From 45621ba4d7b8d95a6f2b78b27ad8ab3a04eb198a Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 23 Feb 2024 14:58:03 +0800 Subject: [PATCH 010/160] add api extension to http request node convert --- api/core/features/external_data_fetch.py | 7 - api/services/workflow/workflow_converter.py | 149 ++++++++++++++++++-- 2 files changed, 135 insertions(+), 21 deletions(-) diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py index 7f23c8ed728096..ef37f055289cb4 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/features/external_data_fetch.py @@ -1,5 +1,4 @@ import concurrent -import json import logging from concurrent.futures import ThreadPoolExecutor from typing import Optional @@ -28,12 +27,6 @@ def fetch(self, tenant_id: str, :param query: the query :return: the filled inputs """ - # Group tools by type and config - grouped_tools = {} - for tool in external_data_tools: - tool_key = (tool.type, json.dumps(tool.config, sort_keys=True)) - 
grouped_tools.setdefault(tool_key, []).append(tool) - results = {} with ThreadPoolExecutor() as executor: futures = {} diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 647713b404cc90..1fb37afe010ffb 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -11,6 +11,7 @@ PromptTemplateEntity, VariableEntity, ) +from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper from core.prompt.simple_prompt_transform import SimplePromptTransform @@ -18,6 +19,7 @@ from core.workflow.nodes.end.entities import EndNodeOutputType from extensions.ext_database import db from models.account import Account +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType @@ -49,7 +51,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: # convert app model config application_manager = ApplicationManager() - application_manager.convert_from_app_model_config_dict( + app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, app_model_config_dict=app_model_config.to_dict() ) @@ -71,24 +73,27 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: # convert to start node start_node = self._convert_to_start_node( - variables=app_model_config.variables + variables=app_orchestration_config_entity.variables ) graph['nodes'].append(start_node) # convert to http request node - if app_model_config.external_data_variables: - http_request_node = self._convert_to_http_request_node( - external_data_variables=app_model_config.external_data_variables + if app_orchestration_config_entity.external_data_variables: + http_request_nodes = self._convert_to_http_request_node( + app_model=app_model, + variables=app_orchestration_config_entity.variables, + external_data_variables=app_orchestration_config_entity.external_data_variables ) - graph = self._append_node(graph, http_request_node) + for http_request_node in http_request_nodes: + graph = self._append_node(graph, http_request_node) # convert to knowledge retrieval node - if app_model_config.dataset: + if app_orchestration_config_entity.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=app_model_config.dataset + dataset_config=app_orchestration_config_entity.dataset ) if knowledge_retrieval_node: @@ -98,9 +103,9 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: llm_node = self._convert_to_llm_node( new_app_mode=new_app_mode, graph=graph, - model_config=app_model_config.model_config, - prompt_template=app_model_config.prompt_template, - file_upload=app_model_config.file_upload + model_config=app_orchestration_config_entity.model_config, + prompt_template=app_orchestration_config_entity.prompt_template, + file_upload=app_orchestration_config_entity.file_upload ) graph = self._append_node(graph, llm_node) @@ -160,14 +165,130 @@ def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: } } - def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict: + def _convert_to_http_request_node(self, app_model: App, + variables: list[VariableEntity], + external_data_variables: 
list[ExternalDataVariableEntity]) -> list[dict]: """ Convert API Based Extension to HTTP Request Node + :param app_model: App instance + :param variables: list of variables :param external_data_variables: list of external data variables :return: """ - # TODO: implement - pass + index = 1 + nodes = [] + tenant_id = app_model.tenant_id + for external_data_variable in external_data_variables: + tool_type = external_data_variable.type + if tool_type != "api": + continue + + tool_variable = external_data_variable.variable + tool_config = external_data_variable.config + + # get params from config + api_based_extension_id = tool_config.get("api_based_extension_id") + + # get api_based_extension + api_based_extension = db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() + + if not api_based_extension: + raise ValueError("[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid" + .format(tool_variable)) + + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=tenant_id, + token=api_based_extension.api_key + ) + + http_request_variables = [] + inputs = {} + for v in variables: + http_request_variables.append({ + "variable": v.variable, + "value_selector": ["start", v.variable] + }) + + inputs[v.variable] = '{{' + v.variable + '}}' + + if app_model.mode == AppMode.CHAT.value: + http_request_variables.append({ + "variable": "_query", + "value_selector": ["start", "sys.query"] + }) + + request_body = { + 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + 'params': { + 'app_id': app_model.id, + 'tool_variable': tool_variable, + 'inputs': inputs, + 'query': '{{_query}}' if app_model.mode == AppMode.CHAT.value else '' + } + } + + request_body_json = json.dumps(request_body) + request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}') + + http_request_node = { + "id": f"http-request-{index}", + "position": None, + "data": { + "title": f"HTTP REQUEST {api_based_extension.name}", + "type": NodeType.HTTP_REQUEST.value, + "variables": http_request_variables, + "method": "post", + "url": api_based_extension.api_endpoint, + "authorization": { + "type": "api-key", + "config": { + "type": "bearer", + "api_key": api_key + } + }, + "headers": "", + "params": "", + "body": { + "type": "json", + "data": request_body_json + } + } + } + index += 1 + + nodes.append(http_request_node) + + # append code node for response body parsing + code_node = { + "id": f"code-{index}", + "position": None, + "data": { + "title": f"Parse {api_based_extension.name} response", + "type": NodeType.CODE.value, + "variables": [{ + "variable": "response_json", + "value_selector": [http_request_node['id'], "body"] + }], + "code_language": "python3", + "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" + "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", + "outputs": [ + { + "variable": "result", + "variable_type": "string" + } + ] + } + } + + nodes.append(code_node) + + return nodes def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \ -> Optional[dict]: From 0806b3163ab45f8149acc493bb7b5c33095ebe65 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 23 Feb 2024 18:18:49 +0800 Subject: [PATCH 011/160] add to http request node convert tests --- api/core/application_manager.py | 8 +- api/core/entities/application_entities.py | 1 + 
api/services/app_model_config_service.py | 2 +- api/services/workflow/workflow_converter.py | 24 ++- api/tests/unit_tests/services/__init__.py | 0 .../unit_tests/services/workflow/__init__.py | 0 .../workflow/test_workflow_converter.py | 184 ++++++++++++++++++ 7 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/services/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_converter.py diff --git a/api/core/application_manager.py b/api/core/application_manager.py index cf463be1df58f4..77bb81b0da0d9d 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -400,10 +400,14 @@ def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_di config=val['config'] ) ) - elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + elif typ in [ + VariableEntity.Type.TEXT_INPUT.value, + VariableEntity.Type.PARAGRAPH.value, + VariableEntity.Type.NUMBER.value, + ]: properties['variables'].append( VariableEntity( - type=VariableEntity.Type.TEXT_INPUT, + type=VariableEntity.Type.value_of(typ), variable=variable[typ].get('variable'), description=variable[typ].get('description'), label=variable[typ].get('label'), diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index f8f293d96a8146..667940f184eb52 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -94,6 +94,7 @@ class Type(Enum): TEXT_INPUT = 'text-input' SELECT = 'select' PARAGRAPH = 'paragraph' + NUMBER = 'number' @classmethod def value_of(cls, value: str) -> 'VariableEntity.Type': diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3ac11c645c9940..aa8cd73ea75daf 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -205,7 +205,7 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "external_data_tool"]: + if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1fb37afe010ffb..31df58a583b526 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -190,10 +190,10 @@ def _convert_to_http_request_node(self, app_model: App, api_based_extension_id = tool_config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = self._get_api_based_extension( + tenant_id=tenant_id, + api_based_extension_id=api_based_extension_id + ) if not api_based_extension: raise ValueError("[External data tool] API query failed, variable: {}, " @@ -259,7 +259,6 @@ def _convert_to_http_request_node(self, app_model: App, } } } - index += 1 nodes.append(http_request_node) @@ -268,7 +267,7 @@ def _convert_to_http_request_node(self, app_model: App, "id": f"code-{index}", "position": None, "data": { - "title": 
f"Parse {api_based_extension.name} response", + "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, "variables": [{ "variable": "response_json", @@ -287,6 +286,7 @@ def _convert_to_http_request_node(self, app_model: App, } nodes.append(code_node) + index += 1 return nodes @@ -513,3 +513,15 @@ def _get_new_app_mode(self, app_model: App) -> AppMode: return AppMode.WORKFLOW else: return AppMode.value_of(app_model.mode) + + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + """ + Get API Based Extension + :param tenant_id: tenant id + :param api_based_extension_id: api based extension id + :return: + """ + return db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() diff --git a/api/tests/unit_tests/services/__init__.py b/api/tests/unit_tests/services/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/services/workflow/__init__.py b/api/tests/unit_tests/services/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py new file mode 100644 index 00000000000000..69cf6afe45aa4d --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -0,0 +1,184 @@ +# test for api/services/workflow/workflow_converter.py +import json +from unittest.mock import MagicMock + +import pytest + +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.helper import encrypter +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import AppMode +from services.workflow.workflow_converter import WorkflowConverter + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity( + variable="text-input", + label="text-input", + type=VariableEntity.Type.TEXT_INPUT + ), + VariableEntity( + variable="paragraph", + label="paragraph", + type=VariableEntity.Type.PARAGRAPH + ), + VariableEntity( + variable="select", + label="select", + type=VariableEntity.Type.SELECT + ) + ] + + +def test__convert_to_start_node(default_variables): + # act + result = WorkflowConverter()._convert_to_start_node(default_variables) + + # assert + assert result["data"]["variables"][0]["variable"] == "text-input" + assert result["data"]["variables"][1]["variable"] == "paragraph" + assert result["data"]["variables"][2]["variable"] == "select" + + +def test__convert_to_http_request_node(default_variables): + """ + Test convert to http request nodes + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = 
workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 4 # appended _query variable + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "{{_query}}" # for chatbot + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_http_request_node_for_workflow_app(default_variables): + """ + Test convert to http request nodes for workflow app + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 3 + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "" + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" From f11bf9153deee59d773d30d073e272d22f0082bc Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 
2024 13:47:43 +0800 Subject: [PATCH 012/160] add more tests --- .../workflow/test_workflow_converter.py | 266 +++++++++++++++++- 1 file changed, 263 insertions(+), 3 deletions(-) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 69cf6afe45aa4d..ee9e5eb2fa039e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,8 +4,12 @@ import pytest -from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ + AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import AppMode from services.workflow.workflow_converter import WorkflowConverter @@ -42,9 +46,9 @@ def test__convert_to_start_node(default_variables): assert result["data"]["variables"][2]["variable"] == "select" -def test__convert_to_http_request_node(default_variables): +def test__convert_to_http_request_node_for_chatbot(default_variables): """ - Test convert to http request nodes + Test convert to http request nodes for chatbot :return: """ app_model = MagicMock() @@ -182,3 +186,259 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): code_node = nodes[1] assert code_node["data"]["type"] == "code" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot(): + new_app_mode = AppMode.CHAT + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", "sys.query"] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app(): + new_app_mode = AppMode.WORKFLOW + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + 
dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." 
+ ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ]) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], list) + assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages) + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + + llm_node = workflow_converter._convert_to_llm_node( + 
new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], dict) + assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt From 7458fde5a51f593376aedeafb78a8cac9cdb146d Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 14:40:52 +0800 Subject: [PATCH 013/160] add agent app convert command --- api/commands.py | 55 ++++++++++++++++++++++++- api/controllers/console/app/workflow.py | 5 ++- api/services/workflow_service.py | 5 ++- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/api/commands.py b/api/commands.py index 250039a3650c2b..9a023b1c485432 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, MessageAnnotation +from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel @@ -370,8 +370,61 @@ def migrate_knowledge_vector_database(): fg='green')) +@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +def convert_to_agent_apps(): + """ + Convert Agent Assistant to Agent App. + """ + click.echo(click.style('Start convert to agent apps.', fg='green')) + + proceeded_app_ids = [] + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id AS id FROM apps a +INNER JOIN app_model_configs am ON a.app_model_config_id=am.id +WHERE a.mode = 'chat' AND am.agent_mode is not null +and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') +and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.query(App).filter(App.id == app_id).first() + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo('Converting app: {}'.format(app.id)) + + try: + app.mode = AppMode.AGENT.value + db.session.commit() + + # update conversation mode to agent + db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + {Conversation.mode: AppMode.AGENT.value} + ) + + db.session.commit() + click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + except Exception as e: + click.echo( + click.style('Convert app error: {} {}'.format(e.__class__.__name__, + str(e)), fg='red')) + + click.echo(click.style('Congratulations! 
Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) + app.cli.add_command(convert_to_agent_apps) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1bb0ea34c12340..7663e22580aff9 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -77,7 +77,10 @@ def post(self, app_model: App): """ # convert to workflow mode workflow_service = WorkflowService() - workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model) + workflow = workflow_service.chatbot_convert_to_workflow( + app_model=app_model, + account=current_user + ) # return workflow return workflow diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6a967e86ffd24f..0cb398225d49b4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -65,11 +65,12 @@ def get_default_block_configs(self) -> dict: # return default block config return default_block_configs - def chatbot_convert_to_workflow(self, app_model: App) -> Workflow: + def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workflow: """ basic mode of chatbot app to workflow :param app_model: App instance + :param account: Account instance :return: """ # check if chatbot app is in basic mode @@ -78,6 +79,6 @@ def chatbot_convert_to_workflow(self, app_model: App) -> Workflow: # convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model) + workflow = workflow_converter.convert_to_workflow(app_model=app_model, account=account) return workflow From 2ba7ac8bc1f0b9d7cf49a2f5cd9d2f3bf19681a3 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 15:52:08 +0800 Subject: [PATCH 014/160] add expert-mode chat app convert command --- api/commands.py | 72 ++++++++++++++++++++- api/core/application_manager.py | 41 ++++++----- api/core/entities/application_entities.py | 2 +- api/services/workflow/workflow_converter.py | 23 +++--- api/services/workflow_service.py | 2 +- .../workflow/test_workflow_converter.py | 2 + 6 files changed, 114 insertions(+), 28 deletions(-) diff --git a/api/commands.py b/api/commands.py index 9a023b1c485432..73d2150de23df9 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,5 +1,6 @@ import base64 import json +import logging import secrets import click @@ -12,11 +13,12 @@ from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant +from models.account import Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from services.workflow.workflow_converter import WorkflowConverter @click.command('reset-password', help='Reset the account password.') @@ -422,9 +424,77 @@ def convert_to_agent_apps(): click.echo(click.style('Congratulations! 
Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) +@click.command('convert-to-workflow-chatbot-apps', help='Convert Basic Export Assistant to Chatbot Workflow App.') +def convert_to_workflow_chatbot_apps(): + """ + Convert Basic Export Assistant to Chatbot Workflow App. + """ + click.echo(click.style('Start convert to workflow chatbot apps.', fg='green')) + + proceeded_app_ids = [] + workflow_converter = WorkflowConverter() + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id FROM apps a +LEFT JOIN app_model_configs am ON a.app_model_config_id=am.id +WHERE a.mode = 'chat' AND am.prompt_type='advanced' ORDER BY a.created_at DESC LIMIT 1000""" + + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + print(app_id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.query(App).filter(App.id == app_id).first() + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo('Converting app: {}'.format(app.id)) + + try: + # get workspace of app + tenant = db.session.query(Tenant).filter(Tenant.id == app.tenant_id).first() + if not tenant: + click.echo(click.style('Tenant not found: {}'.format(app.tenant_id), fg='red')) + continue + + # get workspace owner + tenant_account_join = db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.tenant_id == tenant.id, + TenantAccountJoin.role == 'owner' + ).first() + + if not tenant_account_join: + click.echo(click.style('Tenant owner not found: {}'.format(tenant.id), fg='red')) + continue + + # convert to workflow + workflow_converter.convert_to_workflow( + app_model=app, + account_id=tenant_account_join.account_id + ) + + click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + except Exception as e: + logging.exception('Convert app error: {}'.format(app.id)) + click.echo( + click.style('Convert app error: {} {}'.format(e.__class__.__name__, + str(e)), fg='red')) + + click.echo(click.style('Congratulations! Converted {} workflow chatbot apps.'.format(len(proceeded_app_ids)), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) app.cli.add_command(convert_to_agent_apps) + app.cli.add_command(convert_to_workflow_chatbot_apps) diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 77bb81b0da0d9d..ea0c85427d4811 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -235,12 +235,15 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit logger.exception(e) raise e - def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ + def convert_from_app_model_config_dict(self, tenant_id: str, + app_model_config_dict: dict, + skip_check: bool = False) \ -> AppOrchestrationConfigEntity: """ Convert app model config dict to entity. 
:param tenant_id: tenant ID + :param app_model_config_dict: app model config dict + :param skip_check: skip the model availability and credential checks when True :raises ProviderTokenNotInitError: provider token not init error :return: app orchestration config entity """ @@ -268,24 +271,28 @@ def convert_from_app_model_config_di ) if model_credentials is None: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + if not skip_check: + raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.") + else: + model_credentials = {} - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) + if not skip_check: + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=copy_app_model_config_dict['model']['name'], + model_type=ModelType.LLM + ) - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") + if provider_model is None: + model_name = copy_app_model_config_dict['model']['name'] + raise ValueError(f"Model {model_name} does not exist.") - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = copy_app_model_config_dict['model'].get('completion_params') @@ -309,7 +316,7 @@ def convert_from_app_model_config_di model_credentials ) - if not model_schema: + if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") properties['model_config'] = ModelConfigEntity( diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 667940f184eb52..f5ea4d1eb0f8e0 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -15,7 +15,7 @@ class ModelConfigEntity(BaseModel): """ provider: str model: str - model_schema: AIModelEntity + model_schema: Optional[AIModelEntity] = None mode: str provider_model_bundle: ProviderModelBundle credentials: dict[str, Any] = {} diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 31df58a583b526..1d3cbe2e0e8fb5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,12 +13,11 @@ ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.utils import helper +from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from 
core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from extensions.ext_database import db -from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType @@ -29,7 +28,7 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: + def convert_to_workflow(self, app_model: App, account_id: str) -> Workflow: """ Convert to workflow mode @@ -40,7 +39,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: - completion app (for migration) :param app_model: App instance - :param account: Account instance + :param account_id: Account ID :return: workflow instance """ # get new app mode @@ -53,7 +52,8 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: application_manager = ApplicationManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config.to_dict() + app_model_config_dict=app_model_config.to_dict(), + skip_check=True ) # init workflow graph @@ -122,7 +122,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), - created_by=account.id + created_by=account_id ) db.session.add(workflow) @@ -130,6 +130,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: # create new app model config record new_app_model_config = app_model_config.copy() + new_app_model_config.id = None new_app_model_config.external_data_tools = '' new_app_model_config.model = '' new_app_model_config.user_input_form = '' @@ -147,6 +148,9 @@ def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: db.session.add(new_app_model_config) db.session.commit() + app_model.app_model_config_id = new_app_model_config.id + db.session.commit() + return workflow def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: @@ -161,7 +165,7 @@ def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: "data": { "title": "START", "type": NodeType.START.value, - "variables": [helper.dump_model(v) for v in variables] + "variables": [jsonable_encoder(v) for v in variables] } } @@ -369,7 +373,10 @@ def _convert_to_llm_node(self, new_app_mode: AppMode, ] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template - prompts = [helper.dump_model(m) for m in advanced_chat_prompt_template.messages] \ + prompts = [{ + "role": m.role.value, + "text": m.text + } for m in advanced_chat_prompt_template.messages] \ if advanced_chat_prompt_template else [] # Completion Model else: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0cb398225d49b4..bd88f3cbe2efe8 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -79,6 +79,6 @@ def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workf # convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model, account=account) + workflow = workflow_converter.convert_to_workflow(app_model=app_model, account_id=account.id) return workflow diff --git 
a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index ee9e5eb2fa039e..d4edc73410eac3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -41,6 +41,8 @@ def test__convert_to_start_node(default_variables): result = WorkflowConverter()._convert_to_start_node(default_variables) # assert + assert isinstance(result["data"]["variables"][0]["type"], str) + assert result["data"]["variables"][0]["type"] == "text-input" assert result["data"]["variables"][0]["variable"] == "text-input" assert result["data"]["variables"][1]["variable"] == "paragraph" assert result["data"]["variables"][2]["variable"] == "select" From 748aa22ee2e1deec036378d664bc7d6652886c4e Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:02:28 +0800 Subject: [PATCH 015/160] add manual convert logic --- api/commands.py | 81 +----------- api/controllers/console/app/workflow.py | 8 +- .../versions/b289e2408ee2_add_workflow.py | 2 + api/models/model.py | 1 + api/models/workflow.py | 78 +++++++++++ api/services/workflow/workflow_converter.py | 123 +++++++++++++----- api/services/workflow_service.py | 29 +++-- 7 files changed, 198 insertions(+), 124 deletions(-) diff --git a/api/commands.py b/api/commands.py index 73d2150de23df9..e376d222c65329 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,6 +1,5 @@ import base64 import json -import logging import secrets import click @@ -13,12 +12,11 @@ from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant, TenantAccountJoin +from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel -from services.workflow.workflow_converter import WorkflowConverter @click.command('reset-password', help='Reset the account password.') @@ -384,10 +382,11 @@ def convert_to_agent_apps(): while True: # fetch first 1000 apps sql_query = """SELECT a.id AS id FROM apps a -INNER JOIN app_model_configs am ON a.app_model_config_id=am.id -WHERE a.mode = 'chat' AND am.agent_mode is not null -and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') -and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + INNER JOIN app_model_configs am ON a.app_model_config_id=am.id + WHERE a.mode = 'chat' AND am.agent_mode is not null + and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') + and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query)) @@ -424,77 +423,9 @@ def convert_to_agent_apps(): click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) -@click.command('convert-to-workflow-chatbot-apps', help='Convert Basic Export Assistant to Chatbot Workflow App.') -def convert_to_workflow_chatbot_apps(): - """ - Convert Basic Export Assistant to Chatbot Workflow App. 
- """ - click.echo(click.style('Start convert to workflow chatbot apps.', fg='green')) - - proceeded_app_ids = [] - workflow_converter = WorkflowConverter() - - while True: - # fetch first 1000 apps - sql_query = """SELECT a.id FROM apps a -LEFT JOIN app_model_configs am ON a.app_model_config_id=am.id -WHERE a.mode = 'chat' AND am.prompt_type='advanced' ORDER BY a.created_at DESC LIMIT 1000""" - - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query)) - - apps = [] - for i in rs: - app_id = str(i.id) - print(app_id) - if app_id not in proceeded_app_ids: - proceeded_app_ids.append(app_id) - app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) - - if len(apps) == 0: - break - - for app in apps: - click.echo('Converting app: {}'.format(app.id)) - - try: - # get workspace of app - tenant = db.session.query(Tenant).filter(Tenant.id == app.tenant_id).first() - if not tenant: - click.echo(click.style('Tenant not found: {}'.format(app.tenant_id), fg='red')) - continue - - # get workspace owner - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role == 'owner' - ).first() - - if not tenant_account_join: - click.echo(click.style('Tenant owner not found: {}'.format(tenant.id), fg='red')) - continue - - # convert to workflow - workflow_converter.convert_to_workflow( - app_model=app, - account_id=tenant_account_join.account_id - ) - - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) - except Exception as e: - logging.exception('Convert app error: {}'.format(app.id)) - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) - - click.echo(click.style('Congratulations! Converted {} workflow chatbot apps.'.format(len(proceeded_app_ids)), fg='green')) - - def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) app.cli.add_command(convert_to_agent_apps) - app.cli.add_command(convert_to_workflow_chatbot_apps) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 7663e22580aff9..dc1b7edcaf35ec 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -69,15 +69,15 @@ class ConvertToWorkflowApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) - @marshal_with(workflow_fields) + @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model: App): """ - Convert basic mode of chatbot app to workflow + Convert basic mode of chatbot app(expert mode) to workflow mode + Convert Completion App to Workflow App """ # convert to workflow mode workflow_service = WorkflowService() - workflow = workflow_service.chatbot_convert_to_workflow( + workflow = workflow_service.convert_to_workflow( app_model=app_model, account=current_user ) diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index e9cd2caf3ae3c5..9e04fef2888f94 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -53,6 +53,7 @@ def upgrade(): sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), sa.Column('execution_metadata', sa.Text(), nullable=True), sa.Column('created_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('finished_at', sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') @@ -80,6 +81,7 @@ def upgrade(): sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), sa.Column('currency', sa.String(length=255), nullable=True), sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('finished_at', sa.DateTime(), nullable=True), diff --git a/api/models/model.py b/api/models/model.py index 58e29cd21c6634..1e66fd6c889884 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -28,6 +28,7 @@ class DifySetup(db.Model): class AppMode(Enum): + COMPLETION = 'completion' WORKFLOW = 'workflow' CHAT = 'chat' AGENT = 'agent' diff --git a/api/models/workflow.py b/api/models/workflow.py index 95805e787129d5..251f33b0c08d8c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,6 +7,27 @@ from models.account import Account +class CreatedByRole(Enum): + """ + Created By Role Enum + """ + ACCOUNT = 'account' + END_USER = 'end_user' + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid created by role value {value}') + + class WorkflowType(Enum): """ Workflow Type Enum @@ -99,6 +120,49 @@ def updated_by_account(self): return Account.query.get(self.updated_by) +class WorkflowRunTriggeredFrom(Enum): + """ + Workflow Run Triggered From Enum + """ + DEBUGGING = 'debugging' + APP_RUN = 'app-run' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run triggered from value {value}') + + +class WorkflowRunStatus(Enum): + """ + Workflow Run Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunStatus': + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run status value {value}') + + class WorkflowRun(db.Model): """ Workflow Run @@ -128,6 +192,12 @@ class WorkflowRun(db.Model): - total_price (decimal) `optional` Total cost - currency (string) `optional` Currency, such as USD / RMB - total_steps (int) Total steps (redundant), default 0 + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + - created_by (uuid) Runner ID - created_at (timestamp) Run time - finished_at (timestamp) End time @@ -157,6 +227,7 @@ class WorkflowRun(db.Model): total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255)) total_steps = db.Column(db.Integer, server_default=db.text('0')) + created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) @@ -208,6 +279,12 @@ class WorkflowNodeExecution(db.Model): - currency (string) `optional` Currency, such as USD / RMB - created_at (timestamp) Run time + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + - created_by (uuid) Runner ID - finished_at (timestamp) End time """ @@ -240,6 +317,7 @@ class WorkflowNodeExecution(db.Model): elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) execution_metadata = db.Column(db.Text) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) finished_at = db.Column(db.DateTime) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1d3cbe2e0e8fb5..bb300d1a7797af 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -17,9 +17,11 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType +from events.app_event import app_was_created from extensions.ext_database import db +from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode, ChatbotAppEngine, AppModelConfig, Site from models.workflow import Workflow, WorkflowType @@ -28,26 +30,99 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, account_id: str) -> Workflow: + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ - Convert to workflow mode + Convert app to workflow - basic mode of chatbot app - - advanced mode of assistant app (for migration) + - advanced mode of assistant app - - completion app (for migration) + - completion app :param app_model: App instance + :param account: Account + :return: new App instance + """ + # get original app config + app_model_config = app_model.app_model_config + + # convert app model config + workflow = self.convert_app_model_config_to_workflow( + app_model=app_model, + app_model_config=app_model_config, + account_id=account.id + ) + + # create new app + new_app = App() + new_app.tenant_id = app_model.tenant_id + 
new_app.name = app_model.name + '(workflow)' + new_app.mode = AppMode.CHAT.value \ + if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon = app_model.icon + new_app.icon_background = app_model.icon_background + new_app.enable_site = app_model.enable_site + new_app.enable_api = app_model.enable_api + new_app.api_rpm = app_model.api_rpm + new_app.api_rph = app_model.api_rph + new_app.is_demo = False + new_app.is_public = app_model.is_public + db.session.add(new_app) + db.session.flush() + + # create new app model config record + new_app_model_config = app_model_config.copy() + new_app_model_config.id = None + new_app_model_config.app_id = new_app.id + new_app_model_config.external_data_tools = '' + new_app_model_config.model = '' + new_app_model_config.user_input_form = '' + new_app_model_config.dataset_query_variable = None + new_app_model_config.pre_prompt = None + new_app_model_config.agent_mode = '' + new_app_model_config.prompt_type = 'simple' + new_app_model_config.chat_prompt_config = '' + new_app_model_config.completion_prompt_config = '' + new_app_model_config.dataset_configs = '' + new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ + if app_model.mode == AppMode.CHAT.value else ChatbotAppEngine.NORMAL.value + new_app_model_config.workflow_id = workflow.id + + db.session.add(new_app_model_config) + db.session.flush() + + new_app.app_model_config_id = new_app_model_config.id + db.session.commit() + + site = Site( + app_id=new_app.id, + title=new_app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() + + app_was_created.send(new_app) + + return new_app + + def convert_app_model_config_to_workflow(self, app_model: App, + app_model_config: AppModelConfig, + account_id: str) -> Workflow: + """ + Convert app model config to workflow mode + :param app_model: App instance + :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: workflow instance + :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) - # get original app config - app_model_config = app_model.app_model_config - # convert app model config application_manager = ApplicationManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( @@ -122,33 +197,11 @@ def convert_to_workflow(self, app_model: App, account_id: str) -> Workflow: type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), - created_by=account_id + created_by=account_id, + created_at=app_model_config.created_at ) db.session.add(workflow) - db.session.flush() - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if new_app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - 
db.session.commit() - - app_model.app_model_config_id = new_app_model_config.id db.session.commit() return workflow @@ -469,7 +522,7 @@ def _convert_to_end_node(self, app_model: App) -> dict: "type": NodeType.END.value, } } - elif app_model.mode == "completion": + elif app_model.mode == AppMode.COMPLETION.value: # for original completion app return { "id": "end", @@ -516,7 +569,7 @@ def _get_new_app_mode(self, app_model: App) -> AppMode: :param app_model: App instance :return: AppMode """ - if app_model.mode == "completion": + if app_model.mode == AppMode.COMPLETION.value: return AppMode.WORKFLOW else: return AppMode.value_of(app_model.mode) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bd88f3cbe2efe8..2d9342ffc96997 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ from extensions.ext_database import db from models.account import Account -from models.model import App, ChatbotAppEngine +from models.model import App, ChatbotAppEngine, AppMode from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter @@ -65,20 +65,29 @@ def get_default_block_configs(self) -> dict: # return default block config return default_block_configs - def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workflow: + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ - basic mode of chatbot app to workflow + Basic mode of chatbot app (expert mode) to workflow + Completion App to Workflow App :param app_model: App instance :param account: Account instance :return: """ - # check if chatbot app is in basic mode - if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: - raise ValueError('Chatbot app already in workflow mode') - - # convert to workflow mode + # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model, account_id=account.id) - return workflow + if app_model.mode == AppMode.CHAT.value: + # check if chatbot app is in basic mode + if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: + raise ValueError('Chatbot app already in workflow mode') + elif app_model.mode != AppMode.COMPLETION.value: + raise ValueError(f'Current App mode: {app_model.mode} does not support conversion to workflow.') + + # convert to workflow + new_app = workflow_converter.convert_to_workflow( + app_model=app_model, + account=account + ) + + return new_app From 97c4733e7928b09b33e18c5f3f54856890c78c1f Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:02:38 +0800 Subject: [PATCH 016/160] lint fix --- api/services/workflow/workflow_converter.py | 2 +- api/services/workflow_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index bb300d1a7797af..c6f0bed008986b 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, ChatbotAppEngine, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, ChatbotAppEngine, Site from models.workflow import Workflow, 
WorkflowType diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2d9342ffc96997..4f7262b7d6d993 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ from extensions.ext_database import db from models.account import Account -from models.model import App, ChatbotAppEngine, AppMode +from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter From fce20e483cf4cc4eadd8f3386f4478ac5a50bbfd Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:30:36 +0800 Subject: [PATCH 017/160] restore completion app --- api/controllers/console/app/app.py | 2 +- api/controllers/console/app/completion.py | 4 +- api/controllers/console/app/conversation.py | 4 +- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/explore/message.py | 47 +++++++++++++++ api/controllers/web/message.py | 47 +++++++++++++++ api/core/app_runner/app_runner.py | 19 ++++-- api/core/prompt/prompt_transform.py | 7 +-- api/core/prompt/simple_prompt_transform.py | 38 +++++++----- api/services/app_model_config_service.py | 18 ++++++ api/services/completion_service.py | 60 ++++++++++++++++++- api/services/errors/__init__.py | 2 +- api/services/errors/app.py | 2 + .../prompt/test_simple_prompt_transform.py | 2 + 14 files changed, 224 insertions(+), 30 deletions(-) create mode 100644 api/services/errors/app.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index cf505bedb8aedc..93dc1ca34a5f16 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -80,7 +80,7 @@ def post(self): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') + parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 11fdba177d6734..e62475308faf04 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -37,7 +37,7 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') @@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): account = flask_login.current_user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index daf96411211abd..b808d62eb017b0 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -29,7 +29,7 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required - 
@get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index ea4d5971127def..e3a5112200568b 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def get(self, app_model): account = current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index bef26b4d994d8f..47af28425fa896 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -12,6 +12,7 @@ import services from controllers.console import api from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -23,10 +24,13 @@ NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource +from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -72,6 +76,48 @@ def post(self, installed_app, message_id): return {'result': 'success'} +class MessageMoreLikeThisApi(InstalledAppResource): + def get(self, installed_app, message_id): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this( + app_model=app_model, + user=current_user, + message_id=message_id, + invoke_from=InvokeFrom.EXPLORE, + streaming=streaming + ) + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + 
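The @get_app_model(mode=...) decorator these hunks switch to lives in api/controllers/console/app/wraps.py, whose body is not reproduced in this excerpt; the following is only a sketch of the pattern such a wrapper plausibly implements, loading the app for the current tenant and rejecting mode mismatches:

    from functools import wraps

    from flask_login import current_user
    from werkzeug.exceptions import NotFound

    from extensions.ext_database import db
    from models.model import App, AppMode

    def get_app_model(view=None, *, mode=None):
        def decorator(view_func):
            @wraps(view_func)
            def decorated_view(*args, **kwargs):
                app_id = str(kwargs.pop('app_id'))
                app = db.session.query(App).filter(
                    App.id == app_id,
                    App.tenant_id == current_user.current_tenant_id,
                    App.status == 'normal'
                ).first()
                if not app:
                    raise NotFound("App not found")
                if mode is not None:
                    expected_modes = mode if isinstance(mode, list) else [mode]
                    if AppMode.value_of(app.mode) not in expected_modes:
                        raise NotFound("App mode is not supported by this API")
                kwargs['app_model'] = app
                return view_func(*args, **kwargs)
            return decorated_view
        # supports both bare @get_app_model and @get_app_model(mode=...)
        return decorator(view) if view else decorator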
except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -120,4 +166,5 @@ def get(self, installed_app, message_id): api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') +api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 5120f49c5ecf95..e03bdd63bb2a27 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -11,6 +11,7 @@ import services from controllers.web import api from controllers.web.error import ( + AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, CompletionRequestError, NotChatAppError, @@ -20,11 +21,14 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource +from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields from libs.helper import TimestampField, uuid_value +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -109,6 +113,48 @@ def post(self, app_model, end_user, message_id): return {'result': 'success'} +class MessageMoreLikeThisApi(WebApiResource): + def get(self, app_model, end_user, message_id): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this( + app_model=app_model, + user=end_user, + message_id=message_id, + invoke_from=InvokeFrom.WEB_APP, + streaming=streaming + ) + + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + def compact_response(response: Union[dict, Generator]) -> Response: if 
isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -156,4 +202,5 @@ def get(self, app_model, end_user, message_id): api.add_resource(MessageListApi, '/messages') api.add_resource(MessageFeedbackApi, '/messages//feedbacks') +api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index c6f6268a7a73dc..231530ef086110 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -22,8 +22,9 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import App, Message, MessageAnnotation +from models.model import App, Message, MessageAnnotation, AppMode class AppRunner: @@ -140,11 +141,11 @@ def organize_prompt_messages(self, app_record: App, :param memory: memory :return: """ - prompt_transform = SimplePromptTransform() - # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( + app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, query=query if query else '', @@ -154,7 +155,17 @@ def organize_prompt_messages(self, app_record: App, model_config=model_config ) else: - raise NotImplementedError("Advanced prompt is not supported yet.") + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query if query else '', + files=files, + context=context, + memory=memory, + model_config=model_config + ) + stop = model_config.stop return prompt_messages, stop diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 9596976b6eff77..9c554140b7b16b 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -11,10 +11,9 @@ class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> list[PromptMessage]: - if memory: - rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) - prompt_messages.extend(histories) + rest_tokens = self._calculate_rest_token(prompt_messages, model_config) + histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + prompt_messages.extend(histories) return prompt_messages diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 2f98fbcae8a1c1..a929416be4e0af 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform): """ def get_prompt(self, + app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, @@ -58,6 +59,7 @@ def get_prompt(self, model_mode = ModelMode.value_of(model_config.mode) if 
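compact_response, used by both more-like-this views above, renders the service result as either a single JSON body or a streamed response; the diff context only shows its blocking branch, so the streaming half below is a reconstruction, not a quote from the patch:

    import json
    from collections.abc import Generator
    from typing import Union

    from flask import Response, stream_with_context

    def compact_response(response: Union[dict, Generator]) -> Response:
        if isinstance(response, dict):
            # blocking mode: one JSON document
            return Response(response=json.dumps(response), status=200,
                            mimetype='application/json')

        def generate() -> Generator:
            # streaming mode: relay the service generator as server-sent events
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')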
model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( + app_mode=app_mode, pre_prompt=prompt_template_entity.simple_prompt_template, inputs=inputs, query=query, @@ -68,6 +70,7 @@ def get_prompt(self, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( + app_mode=app_mode, pre_prompt=prompt_template_entity.simple_prompt_template, inputs=inputs, query=query, @@ -154,7 +157,8 @@ def get_prompt_template(self, app_mode: AppMode, "prompt_rules": prompt_rules } - def _get_chat_model_prompt_messages(self, pre_prompt: str, + def _get_chat_model_prompt_messages(self, app_mode: AppMode, + pre_prompt: str, inputs: dict, query: str, context: Optional[str], @@ -166,7 +170,7 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, # get prompt prompt, _ = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -175,19 +179,25 @@ def _get_chat_model_prompt_messages(self, pre_prompt: str, ) if prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) + if query: + prompt_messages.append(SystemPromptMessage(content=prompt)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) - prompt_messages = self._append_chat_histories( - memory=memory, - prompt_messages=prompt_messages, - model_config=model_config - ) + if memory: + prompt_messages = self._append_chat_histories( + memory=memory, + prompt_messages=prompt_messages, + model_config=model_config + ) - prompt_messages.append(self.get_last_user_message(query, files)) + if query: + prompt_messages.append(self.get_last_user_message(query, files)) return prompt_messages, None - def _get_completion_model_prompt_messages(self, pre_prompt: str, + def _get_completion_model_prompt_messages(self, app_mode: AppMode, + pre_prompt: str, inputs: dict, query: str, context: Optional[str], @@ -197,7 +207,7 @@ def _get_completion_model_prompt_messages(self, pre_prompt: str, -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -220,7 +230,7 @@ def _get_completion_model_prompt_messages(self, pre_prompt: str, # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -289,13 +299,13 @@ def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str is_baichuan = True if is_baichuan: - if app_mode == AppMode.WORKFLOW: + if app_mode == AppMode.COMPLETION: return 'baichuan_completion' else: return 'baichuan_chat' # common - if app_mode == AppMode.WORKFLOW: + if app_mode == AppMode.COMPLETION: return 'common_completion' else: return 'common_chat' diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index aa8cd73ea75daf..34b6d62d51a47c 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -316,6 +316,9 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, if "tool_parameters" not in tool: raise ValueError("tool_parameters is required in agent_mode.tools") + # dataset_query_variable + cls.is_dataset_query_variable_valid(config, app_mode) + # advanced prompt validation cls.is_advanced_prompt_valid(config, app_mode) @@ -441,6 +444,21 @@ def 
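Taken together, the branches above produce one of two message layouts for chat models, depending on whether a user query is present; a rough illustration with made-up contents (real prompts are rendered from the configured templates):

    from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

    # chat app with a query: system prompt, then memory history, then the query
    with_query = [
        SystemPromptMessage(content="You are a helpful translator..."),
        # ...history messages appended from TokenBufferMemory when present...
        UserPromptMessage(content="How are you?"),
    ]

    # completion-style call without a query: the rendered prompt itself
    # becomes the (only) user message
    without_query = [
        UserPromptMessage(content="Please translate the following text..."),
    ]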
is_external_data_tools_valid(cls, tenant_id: str, config: dict):
             config=config
         )
 
+    @classmethod
+    def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
+        # Only check when mode is completion
+        if mode != 'completion':
+            return
+
+        agent_mode = config.get("agent_mode", {})
+        tools = agent_mode.get("tools", [])
+        dataset_exists = "dataset" in str(tools)
+
+        dataset_query_variable = config.get("dataset_query_variable")
+
+        if dataset_exists and not dataset_query_variable:
+            raise ValueError("Dataset query variable is required when a dataset exists")
+
     @classmethod
     def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
         # prompt_type
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 5599c60113c3b5..cbfbe9ef416b63 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -8,10 +8,12 @@
 from core.entities.application_entities import InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from extensions.ext_database import db
-from models.model import Account, App, AppModelConfig, Conversation, EndUser
+from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
 from services.app_model_config_service import AppModelConfigService
+from services.errors.app import MoreLikeThisDisabledError
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 
 class CompletionService:
@@ -155,6 +157,62 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
             }
         )
 
+    @classmethod
+    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
+                                message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
+            -> Union[dict, Generator]:
+        if not user:
+            raise ValueError('user cannot be None')
+
+        message = db.session.query(Message).filter(
+            Message.id == message_id,
+            Message.app_id == app_model.id,
+            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
+            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+            Message.from_account_id == (user.id if isinstance(user, Account) else None),
+        ).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        current_app_model_config = app_model.app_model_config
+        more_like_this = current_app_model_config.more_like_this_dict
+
+        if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
+            raise MoreLikeThisDisabledError()
+
+        app_model_config = message.app_model_config
+        model_dict = app_model_config.model_dict
+        completion_params = model_dict.get('completion_params')
+        completion_params['temperature'] = 0.9
+        model_dict['completion_params'] = completion_params
+        app_model_config.model = json.dumps(model_dict)
+
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.transform_message_files(
+            message.files, app_model_config
+        )
+
+        application_manager = ApplicationManager()
+        return application_manager.generate(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            app_model_config_id=app_model_config.id,
+            app_model_config_dict=app_model_config.to_dict(),
+            app_model_config_override=True,
+            user=user,
+            invoke_from=invoke_from,
+            inputs=message.inputs,
+            query=message.query,
+            files=file_objs,
+            conversation=None,
+            stream=streaming,
+
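The new dataset_query_variable check only fires for completion mode; a small illustration of configs that fail and pass it (the values are made up for the example):

    from services.app_model_config_service import AppModelConfigService

    config = {
        "agent_mode": {"tools": [{"dataset": {"id": "some-dataset-id"}}]},
        # no "dataset_query_variable" yet -> would raise ValueError below
    }
    # AppModelConfigService.is_dataset_query_variable_valid(config, 'completion')

    config["dataset_query_variable"] = "query"
    AppModelConfigService.is_dataset_query_variable_valid(config, 'completion')  # passes
    AppModelConfigService.is_dataset_query_variable_valid(config, 'chat')        # returns early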
extras={ + "auto_generate_conversation_name": False + } + ) + @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index a44c190cbc1d28..5804f599fe63bf 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'completion', 'audio', 'file' + 'app', 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/app.py b/api/services/errors/app.py new file mode 100644 index 00000000000000..7c4ca99c2ae869 --- /dev/null +++ b/api/services/errors/app.py @@ -0,0 +1,2 @@ +class MoreLikeThisDisabledError(Exception): + pass diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index c174983e384e99..a95a6dc52f524c 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages(): context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, pre_prompt=pre_prompt, inputs=inputs, query=query, @@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages(): context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, pre_prompt=pre_prompt, inputs=inputs, query=query, From 98cb17e79e7c5bf827292889ed8f496b7362453a Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:30:44 +0800 Subject: [PATCH 018/160] lint fix --- api/core/app_runner/app_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index 231530ef086110..95f2f568dcfad2 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -24,7 +24,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import App, Message, MessageAnnotation, AppMode +from models.model import App, AppMode, Message, MessageAnnotation class AppRunner: From 34ed5e428cdf2f116156033e3ae3dfa33b53651a Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:55:39 +0800 Subject: [PATCH 019/160] fix bugs --- api/core/prompt/advanced_prompt_transform.py | 36 +++++++++++++------ .../prompt/test_advanced_prompt_transform.py | 1 + 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0ed9ec352cef96..7519971ce75292 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -39,6 +39,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, prompt_messages = self._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=query, files=files, context=context, memory=memory, @@ -60,6 +61,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, def 
_get_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, + query: Optional[str], files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], @@ -86,6 +88,9 @@ def _get_completion_model_prompt_messages(self, model_config=model_config ) + if query: + prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt = prompt_template.format( prompt_inputs ) @@ -147,21 +152,30 @@ def _get_chat_model_prompt_messages(self, else: prompt_messages.append(UserPromptMessage(content=query)) elif files: - # get last message - last_message = prompt_messages[-1] if prompt_messages else None - if last_message and last_message.role == PromptMessageRole.USER: - # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - last_message.content = prompt_message_contents + if not query: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: prompt_message_contents.append(file.prompt_message_content) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + elif query: + prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages @@ -210,4 +224,4 @@ def _set_histories_variable(self, memory: TokenBufferMemory, else: prompt_inputs['#histories#'] = '' - return prompt_inputs + return prompt_inputs diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 65a160a8e592de..95f1e30b449c72 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages(): prompt_messages = prompt_transform._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=None, files=files, context=context, memory=memory, From 77f04603b3633c809e03d6e4b7b4d79d18d6ce59 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 22:11:20 +0800 Subject: [PATCH 020/160] fix bugs --- api/core/prompt/simple_prompt_transform.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a929416be4e0af..fcae0dc78643bb 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -178,11 +178,8 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, context=context ) - if prompt: - if query: - 
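The file/query handling above reduces to three outcomes; a condensed sketch of the user message appended in each case, using the same content types as the diff (file_contents stands in for the per-file prompt_message_content objects):

    from core.model_runtime.entities.message_entities import (
        TextPromptMessageContent,
        UserPromptMessage,
    )

    # query + files -> one multimodal user message
    multimodal = UserPromptMessage(content=[
        TextPromptMessageContent(data="How are you?"),
        # *file_contents,
    ])

    # files without a query -> files are attached to the last user message
    # when there is one, otherwise to a new user message with empty text

    # query without files -> a plain text user message
    plain = UserPromptMessage(content="How are you?")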
prompt_messages.append(SystemPromptMessage(content=prompt)) - else: - prompt_messages.append(UserPromptMessage(content=prompt)) + if prompt and query: + prompt_messages.append(SystemPromptMessage(content=prompt)) if memory: prompt_messages = self._append_chat_histories( @@ -193,6 +190,8 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, if query: prompt_messages.append(self.get_last_user_message(query, files)) + else: + prompt_messages.append(self.get_last_user_message(prompt, files)) return prompt_messages, None From a9192bc1c63352fbf3134100ca9db355fa02dbe0 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 26 Feb 2024 12:43:46 +0800 Subject: [PATCH 021/160] make recommended app list api public --- .../console/explore/recommended_app.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index fd90be03b16743..8b8fe349ed1c76 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, fields, marshal_with +from flask_restful import Resource, fields, marshal_with, reqparse from sqlalchemy import and_ from constants.languages import languages @@ -28,9 +28,6 @@ 'category': fields.String, 'position': fields.Integer, 'is_listed': fields.Boolean, - 'install_count': fields.Integer, - 'installed': fields.Boolean, - 'editable': fields.Boolean, 'is_agent': fields.Boolean } @@ -41,11 +38,19 @@ class RecommendedAppListApi(Resource): - @login_required - @account_initialization_required @marshal_with(recommended_app_list_fields) def get(self): - language_prefix = current_user.interface_language if current_user.interface_language else languages[0] + # language args + parser = reqparse.RequestParser() + parser.add_argument('language', type=str, location='args') + args = parser.parse_args() + + if args.get('language') and args.get('language') in languages: + language_prefix = args.get('language') + elif current_user and current_user.interface_language: + language_prefix = current_user.interface_language + else: + language_prefix = languages[0] recommended_apps = db.session.query(RecommendedApp).filter( RecommendedApp.is_listed == True, @@ -53,16 +58,8 @@ def get(self): ).all() categories = set() - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) recommended_apps_result = [] for recommended_app in recommended_apps: - installed = db.session.query(InstalledApp).filter( - and_( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id == current_user.current_tenant_id - ) - ).first() is not None - app = recommended_app.app if not app or not app.is_public: continue @@ -81,9 +78,6 @@ def get(self): 'category': recommended_app.category, 'position': recommended_app.position, 'is_listed': recommended_app.is_listed, - 'install_count': recommended_app.install_count, - 'installed': installed, - 'editable': current_user.role in ['owner', 'admin'], "is_agent": app.is_agent } recommended_apps_result.append(recommended_app_result) @@ -114,8 +108,6 @@ class RecommendedAppApi(Resource): 'app_model_config': fields.Nested(model_config_fields), } - @login_required - @account_initialization_required @marshal_with(app_simple_detail_fields) def get(self, app_id): app_id = str(app_id) From 78afba49bf336542bec774bc9b859d57c4556f7a Mon Sep 17 00:00:00 2001 From: takatost Date: 
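The language resolution in the now-public recommended app list falls back in three steps; the same selection written out as a standalone helper for clarity (a paraphrase of the view code above, not a new API):

    from constants.languages import languages

    def resolve_language(arg_language, user):
        if arg_language and arg_language in languages:
            return arg_language                 # explicit ?language= argument
        if user and user.interface_language:
            return user.interface_language      # logged-in console user
        return languages[0]                     # global default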
Mon, 26 Feb 2024 12:44:21 +0800 Subject: [PATCH 022/160] lint fix --- api/controllers/console/explore/recommended_app.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 8b8fe349ed1c76..6ba04d603ac479 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,11 @@ from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse -from sqlalchemy import and_ from constants.languages import languages from controllers.console import api from controllers.console.app.error import AppNotFoundError -from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.login import login_required -from models.model import App, InstalledApp, RecommendedApp -from services.account_service import TenantService +from models.model import App, RecommendedApp app_fields = { 'id': fields.String, From 27ba5a0bce66879969e6da9d2554b86815fdcb76 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:23:01 +0800 Subject: [PATCH 023/160] refactor app mode add app import and export --- api/constants/languages.py | 509 ------------------ api/constants/model_template.py | 97 ++-- api/controllers/console/app/app.py | 239 +++++--- api/controllers/console/app/workflow.py | 11 +- api/controllers/console/app/wraps.py | 18 +- .../console/explore/installed_app.py | 3 +- .../console/explore/recommended_app.py | 64 ++- api/core/provider_manager.py | 2 +- api/fields/app_fields.py | 12 - api/fields/installed_app_fields.py | 3 +- .../versions/b289e2408ee2_add_workflow.py | 2 - ...998d4d_set_model_config_column_nullable.py | 70 +++ api/models/model.py | 53 +- api/services/workflow/workflow_converter.py | 4 +- api/services/workflow_service.py | 43 +- 15 files changed, 370 insertions(+), 760 deletions(-) create mode 100644 api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py diff --git a/api/constants/languages.py b/api/constants/languages.py index 0ae69d77d20283..0147dd8d705104 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -91,512 +91,3 @@ def supported_language(lang): } ], } - -demo_model_templates = { - 'en-US': [ - { - 'name': 'Translation Assistant', - 'icon': '', - 'icon_background': '', - 'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Please translate the following text into {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Target Language", - "description": "The language you want to translate into.", - "type": "select", - "default": "Chinese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": 
"completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Target Language", - "variable": "target_language", - "description": "The language you want to translate into.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI Front-end Interviewer', - 'icon': '', - 'icon_background': '', - 'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - 'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - suggested_questions=None, - pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? 
After the user answers this question, please express your understanding and support.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'zh-Hans': [ - { - 'name': '翻译助手', - 'icon': '', - 'icon_background': '', - 'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "请将以下文本翻译为{{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "目标语言", - "description": "翻译的目标语言", - "type": "select", - "default": "中文", - "options": [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "目标语言", - "variable": "target_language", - "description": "翻译的目标语言", - "default": "中文", - "required": True, - 'options': [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - }, { - "paragraph": { - "label": "文本内容", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI 前端面试官', - 'icon': '', - 'icon_background': '', - 'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - 'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - suggested_questions=None, - pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'uk-UA': [ - { - "name": "Помічник перекладу", - "icon": "", - "icon_background": "", - "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", - "mode": "completion", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo-instruct", - configs={ - "prompt_template": "Будь ласка, 
перекладіть наступний текст на {{target_language}}:\n", - "prompt_variables": [ - { - "key": "target_language", - "name": "Цільова мова", - "description": "Мова, на яку ви хочете перекласти.", - "type": "select", - "default": "Ukrainian", - "options": [ - "Chinese", - "English", - "Japanese", - "French", - "Russian", - "German", - "Spanish", - "Korean", - "Italian", - ], - }, - ], - "completion_params": { - "max_token": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="", - suggested_questions=None, - pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Цільова мова", - "variable": "target_language", - "description": "Мова, на яку ви хочете перекласти.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Запит", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - "name": "AI інтерв’юер фронтенду", - "icon": "", - "icon_background": "", - "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.", - "mode": "chat", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo", - configs={ - "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ", - "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - "prompt_variables": [], - "completion_params": { - "max_token": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. 
", - suggested_questions=None, - pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=None - ), - } - ], - 'vi-VN': [ - { - 'name': 'Trợ lý dịch thuật', - 'icon': '', - 'icon_background': '', - 'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Ngôn ngữ đích", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "type": "select", - "default": "Vietnamese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Ngôn ngữ đích", - "variable": "target_language", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "default": "Vietnamese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'Phỏng vấn front-end AI', - 'icon': '', - 'icon_background': '', - 'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. 
', - 'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ', - suggested_questions=None, - pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], -} diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c22306ac87b17b..ca0b7549897bc0 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,50 +1,25 @@ -import json +from models.model import AppMode -model_templates = { +default_app_templates = { # workflow default mode - 'workflow_default': { + AppMode.WORKFLOW: { 'app': { - 'mode': 'workflow', + 'mode': AppMode.WORKFLOW.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, - 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {} - } + 'model_config': {} }, # chat default mode - 'chat_default': { + AppMode.CHAT: { 'app': { - 'mode': 'chat', + 'mode': AppMode.CHAT.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, 'model_config': { - 'provider': 'openai', - 'model_id': 'gpt-4', - 'configs': { - 'prompt_template': '', - 'prompt_variables': [], - 'completion_params': { - 'max_token': 512, - 'temperature': 1, - 'top_p': 1, - 'presence_penalty': 0, - 'frequency_penalty': 0, - } - }, - 'model': json.dumps({ + 'model': { "provider": "openai", "name": "gpt-4", "mode": "chat", @@ -55,36 +30,42 @@ "presence_penalty": 0, "frequency_penalty": 0 } - }) + } } }, - # agent default mode - 'agent_default': { + # advanced-chat 
default mode + AppMode.ADVANCED_CHAT: { 'app': { - 'mode': 'agent', + 'mode': AppMode.ADVANCED_CHAT.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, 'model_config': { - 'provider': 'openai', - 'model_id': 'gpt-4', - 'configs': { - 'prompt_template': '', - 'prompt_variables': [], - 'completion_params': { - 'max_token': 512, - 'temperature': 1, - 'top_p': 1, - 'presence_penalty': 0, - 'frequency_penalty': 0, + 'model': { + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 } - }, - 'model': json.dumps({ + } + } + }, + + # agent-chat default mode + AppMode.AGENT_CHAT: { + 'app': { + 'mode': AppMode.AGENT_CHAT.value, + 'enable_site': True, + 'enable_api': True + }, + 'model_config': { + 'model': { "provider": "openai", "name": "gpt-4", "mode": "chat", @@ -95,7 +76,7 @@ "presence_penalty": 0, "frequency_penalty": 0 } - }) + } } }, } diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 93dc1ca34a5f16..4c218bef1bc2f1 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,13 +1,15 @@ import json import logging from datetime import datetime +from typing import cast +import yaml from flask_login import current_user from flask_restful import Resource, abort, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden -from constants.languages import demo_model_templates, languages -from constants.model_template import model_templates +from constants.languages import languages +from constants.model_template import default_app_templates from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.wraps import get_app_model @@ -15,7 +17,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db @@ -28,10 +31,15 @@ from libs.login import login_required from models.model import App, AppModelConfig, Site, AppMode from services.app_model_config_service import AppModelConfigService +from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager from core.entities.application_entities import AgentToolEntity + +ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] + + class AppListApi(Resource): @setup_required @@ -43,7 +51,7 @@ def get(self): parser = reqparse.RequestParser() parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False) + 
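App creation now starts from default_app_templates keyed by AppMode instead of the old string-keyed model_templates; a minimal lookup mirroring what the create endpoint does with it:

    from constants.model_template import default_app_templates
    from models.model import AppMode

    app_mode = AppMode.value_of('agent-chat')
    app_template = default_app_templates[app_mode]
    app_kwargs = app_template['app']                     # mode, enable_site, enable_api
    default_model_config = app_template['model_config']  # may carry a 'model' dict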
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent', 'channel', 'all'], default='all', location='args', required=False) parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() @@ -52,15 +60,20 @@ def get(self): App.is_universal == False ] - if args['mode'] == 'completion': - filters.append(App.mode == 'completion') + if args['mode'] == 'workflow': + filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) elif args['mode'] == 'chat': - filters.append(App.mode == 'chat') + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args['mode'] == 'agent': + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args['mode'] == 'channel': + filters.append(App.mode == AppMode.CHANNEL.value) else: pass if 'name' in args and args['name']: - filters.append(App.name.ilike(f'%{args["name"]}%')) + name = args['name'][:30] + filters.append(App.name.ilike(f'%{name}%')) app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), @@ -80,10 +93,9 @@ def post(self): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json') + parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('model_config', type=dict, location='json') args = parser.parse_args() # The role of the current user in the ta table must be admin or owner @@ -141,15 +153,15 @@ def post(self): app_mode = AppMode.value_of(args['mode']) - model_config_template = model_templates[app_mode.value + '_default'] + app_template = default_app_templates[app_mode] - app = App(**model_config_template['app']) - app_model_config = AppModelConfig(**model_config_template['model_config']) - - if app_mode in [AppMode.CHAT, AppMode.AGENT]: + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: # get model provider model_manager = ModelManager() + # get default model instance try: model_instance = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, @@ -159,10 +171,25 @@ def post(self): model_instance = None if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) + if model_instance.model == default_model_config['model']['name']: + default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] + + default_model_config['model'] = json.dumps(default_model_dict) + + app = App(**app_template['app']) + app_model_config = AppModelConfig(**default_model_config) app.name = args['name'] app.mode = args['mode'] @@ -195,24 +222,95 @@ def post(self): app_was_created.send(app) return app, 201 - -class AppTemplateApi(Resource): +class 
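When the workspace has a default LLM configured, the create endpoint swaps it into the template; if that model differs from the template's, its mode is read from the provider's model schema. A condensed sketch of the resolution (error handling trimmed, tenant_id is a placeholder):

    from core.model_manager import ModelManager
    from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType

    model_manager = ModelManager()
    model_instance = model_manager.get_default_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM
    )
    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    default_model_dict = {
        'provider': model_instance.provider,
        'name': model_instance.model,
        'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
        'completion_params': {}
    }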
AppImportApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(template_list_fields) - def get(self): - """Get app demo templates""" + @marshal_with(app_detail_fields) + @cloud_edition_billing_resource_check('apps') + def post(self): + """Import app""" + # The role of the current user in the ta table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument('name', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + try: + import_data = yaml.safe_load(args['data']) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML format in data argument.") + + app_data = import_data.get('app') + model_config_data = import_data.get('model_config') + workflow_graph = import_data.get('workflow_graph') + + if not app_data or not model_config_data: + raise ValueError("Missing app or model_config in data argument") + + app_mode = AppMode.value_of(app_data.get('mode')) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if not workflow_graph: + raise ValueError("Missing workflow_graph in data argument " + "when mode is advanced-chat or workflow") + + app = App( + enable_site=True, + enable_api=True, + is_demo=False, + api_rpm=0, + api_rph=0, + status='normal' + ) + + app.tenant_id = current_user.current_tenant_id + app.mode = app_data.get('mode') + app.name = args.get("name") if args.get("name") else app_data.get('name') + app.icon = args.get("icon") if args.get("icon") else app_data.get('icon') + app.icon_background = args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background') + + db.session.add(app) + db.session.commit() + + if workflow_graph: + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, current_user) + published_workflow = workflow_service.publish_draft_workflow(app, current_user, draft_workflow) + model_config_data['workflow_id'] = published_workflow.id + + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + account = current_user - interface_language = account.interface_language - templates = demo_model_templates.get(interface_language) - if not templates: - templates = demo_model_templates.get(languages[0]) + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() - return {'data': templates} + app_was_created.send(app) + + return app, 201 class AppApi(Resource): @@ -283,6 +381,38 @@ def delete(self, app_model): return {'result': 'success'}, 204 +class AppExportApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + """Export app""" + app_model_config = app_model.app_model_config + + export_data = { + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": app_model.icon, + "icon_background": app_model.icon_background + }, + "model_config": app_model_config.to_dict(), + } + + if 
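The import endpoint above expects the data field to be one YAML document; a minimal payload that satisfies its validation for a basic chat app might look like this (all field values are illustrative):

    import_yaml = '''
    app:
      name: Translation Assistant
      mode: chat
      icon: ""
      icon_background: ""
    model_config:
      model:
        provider: openai
        name: gpt-4
        mode: chat
        completion_params: {}
    '''
    # POST /apps/import with a JSON body of {"data": import_yaml};
    # advanced-chat and workflow apps must additionally supply workflow_graph.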
app_model_config.workflow_id: + export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + else: + # get draft workflow + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + export_data['workflow_graph'] = json.loads(workflow.graph) + + return { + "data": yaml.dump(export_data) + } + + class AppNameApi(Resource): @setup_required @login_required @@ -360,57 +490,10 @@ def post(self, app_model): return app_model -class AppCopy(Resource): - @staticmethod - def create_app_copy(app): - copy_app = App( - name=app.name + ' copy', - icon=app.icon, - icon_background=app.icon_background, - tenant_id=app.tenant_id, - mode=app.mode, - app_model_config_id=app.app_model_config_id, - enable_site=app.enable_site, - enable_api=app.enable_api, - api_rpm=app.api_rpm, - api_rph=app.api_rph - ) - return copy_app - - @staticmethod - def create_app_model_config_copy(app_config, copy_app_id): - copy_app_model_config = app_config.copy() - copy_app_model_config.app_id = copy_app_id - - return copy_app_model_config - - @setup_required - @login_required - @account_initialization_required - @get_app_model - @marshal_with(app_detail_fields) - def post(self, app_model): - copy_app = self.create_app_copy(app_model) - db.session.add(copy_app) - - app_config = db.session.query(AppModelConfig). \ - filter(AppModelConfig.app_id == app_model.id). \ - one_or_none() - - if app_config: - copy_app_model_config = self.create_app_model_config_copy(app_config, copy_app.id) - db.session.add(copy_app_model_config) - db.session.commit() - copy_app.app_model_config_id = copy_app_model_config.id - db.session.commit() - - return copy_app, 201 - - api.add_resource(AppListApi, '/apps') -api.add_resource(AppTemplateApi, '/app-templates') +api.add_resource(AppImportApi, '/apps/import') api.add_resource(AppApi, '/apps/<uuid:app_id>') -api.add_resource(AppCopy, '/apps/<uuid:app_id>/copy') +api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export') api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name') api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon') api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable') diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index dc1b7edcaf35ec..6023d0ba45a9d4 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from controllers.console.wraps import account_initialization_required from fields.workflow_fields import workflow_fields from libs.login import current_user, login_required -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode from services.workflow_service import WorkflowService @@ -15,7 +15,7 @@ class DraftWorkflowApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @marshal_with(workflow_fields) def get(self, app_model: App): """ @@ -34,7 +34,7 @@ def get(self, app_model: App): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def post(self, app_model: App): """ Sync draft workflow @@ -55,7 +55,7 @@ class DefaultBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) +
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def get(self, app_model: App): """ Get default block config @@ -72,7 +72,8 @@ class ConvertToWorkflowApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model: App): """ - Convert basic mode of chatbot app(expert mode) to workflow mode + Convert basic mode of chatbot app to workflow mode + Convert expert mode of chatbot app to workflow mode Convert Completion App to Workflow App """ # convert to workflow mode diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 1c2c4cf5c7c970..d61ab6d6ae8f28 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,12 +5,11 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None, - app_engine: ChatbotAppEngine = None): + mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -32,6 +31,9 @@ def decorated_view(*args, **kwargs): raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.CHANNEL: + raise AppNotFoundError() + if mode is not None: if isinstance(mode, list): modes = mode @@ -42,16 +44,6 @@ def decorated_view(*args, **kwargs): mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - if app_engine is not None: - if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]: - raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.") - - if app_mode == AppMode.CHAT: - # fetch current app model config - app_model_config = app_model.app_model_config - if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value: - raise AppNotFoundError(f"{app_engine.value} app engine is not supported.") - kwargs['app_model'] = app_model return view_func(*args, **kwargs) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 920d9141ae1189..7d6231270f23d8 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -34,8 +34,7 @@ def get(self): 'is_pinned': installed_app.is_pinned, 'last_used_at': installed_app.last_used_at, 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id, - 'is_agent': installed_app.is_agent + 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id } for installed_app in installed_apps ] diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6ba04d603ac479..3c28980f5188de 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,3 +1,6 @@ +import json + +import yaml from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse @@ -6,6 +9,7 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from models.model import App, RecommendedApp +from services.workflow_service import WorkflowService app_fields = { 'id': fields.String, @@ -23,8 +27,7 @@ 
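# [Editor's aside] get_app_model(view=None, *, mode=...) in wraps.py above is an
# optionally parameterized decorator, so both `@get_app_model` and
# `@get_app_model(mode=[...])` work. A self-contained sketch of that dispatch,
# with a dict standing in for the App table and LookupError for AppNotFoundError:
from functools import wraps
from typing import Callable, Optional

_APPS = {'a1': {'id': 'a1', 'mode': 'chat'}}  # hypothetical stand-in data

def app_loader(view: Optional[Callable] = None, *, mode: Optional[list] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated(*args, **kwargs):
            app = _APPS.get(str(kwargs.pop('app_id')))
            if app is None or (mode is not None and app['mode'] not in mode):
                raise LookupError('App not found')
            kwargs['app_model'] = app  # injected, as kwargs['app_model'] is above
            return view_func(*args, **kwargs)
        return decorated
    # bare use receives the function as `view`; parameterized use receives None
    return decorator(view) if view is not None else decorator

@app_loader(mode=['chat'])
def show(app_model):
    return app_model['id']

assert show(app_id='a1') == 'a1'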
'privacy_policy': fields.String, 'category': fields.String, 'position': fields.Integer, - 'is_listed': fields.Boolean, - 'is_agent': fields.Boolean + 'is_listed': fields.Boolean } recommended_app_list_fields = { @@ -73,8 +76,7 @@ def get(self): 'privacy_policy': site.privacy_policy, 'category': recommended_app.category, 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed, - "is_agent": app.is_agent + 'is_listed': recommended_app.is_listed } recommended_apps_result.append(recommended_app_result) @@ -84,27 +86,6 @@ def get(self): class RecommendedAppApi(Resource): - model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - } - - app_simple_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'mode': fields.String, - 'app_model_config': fields.Nested(model_config_fields), - } - - @marshal_with(app_simple_detail_fields) def get(self, app_id): app_id = str(app_id) @@ -118,11 +99,38 @@ def get(self, app_id): raise AppNotFoundError # get app detail - app = db.session.query(App).filter(App.id == app_id).first() - if not app or not app.is_public: + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: raise AppNotFoundError - return app + app_model_config = app_model.app_model_config + + export_data = { + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": app_model.icon, + "icon_background": app_model.icon_background + }, + "model_config": app_model_config.to_dict(), + } + + if app_model_config.workflow_id: + export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + else: + # get draft workflow + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + export_data['workflow_graph'] = json.loads(workflow.graph) + + return { + 'id': app_model.id, + 'name': app_model.name, + 'icon': app_model.icon, + 'icon_background': app_model.icon_background, + 'mode': app_model.mode, + 'export_data': yaml.dump(export_data) + } api.add_resource(RecommendedAppListApi, '/explore/apps') diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6e28247d38d56a..0db84d3b6959a6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -235,7 +235,7 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D if available_models: found = False for available_model in available_models: - if available_model.model == "gpt-3.5-turbo-1106": + if available_model.model == "gpt-4": default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index e6c1272086581d..75b68d24fcd449 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -42,14 +42,10 @@ 'id': fields.String, 'name': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 
'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'created_at': TimestampField } @@ -67,12 +63,8 @@ 'id': fields.String, 'name': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), 'created_at': TimestampField } @@ -122,10 +114,6 @@ 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_agent': fields.Boolean, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 821d3c0adef3ab..35cc5a64755eca 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -17,8 +17,7 @@ 'is_pinned': fields.Boolean, 'last_used_at': TimestampField, 'editable': fields.Boolean, - 'uninstallable': fields.Boolean, - 'is_agent': fields.Boolean, + 'uninstallable': fields.Boolean } installed_app_list_fields = { diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 9e04fef2888f94..7255b4b5fa6ba3 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -107,7 +107,6 @@ def upgrade(): batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) with op.batch_alter_table('messages', schema=None) as batch_op: @@ -123,7 +122,6 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('workflow_id') - batch_op.drop_column('chatbot_app_engine') with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.drop_index('workflow_version_idx') diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py new file mode 100644 index 00000000000000..c302e8b5302e5c --- /dev/null +++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py @@ -0,0 +1,70 @@ +"""set model config column nullable + +Revision ID: cc04d0998d4d +Revises: b289e2408ee2 +Create Date: 2024-02-27 03:47:47.376325 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'cc04d0998d4d' +down_revision = 'b289e2408ee2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 1e66fd6c889884..713d8da5775bf7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,7 +31,9 @@ class AppMode(Enum): COMPLETION = 'completion' WORKFLOW = 'workflow' CHAT = 'chat' - AGENT = 'agent' + ADVANCED_CHAT = 'advanced-chat' + AGENT_CHAT = 'agent-chat' + CHANNEL = 'channel' @classmethod def value_of(cls, value: str) -> 'AppMode': @@ -64,8 +66,8 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False) - api_rph = db.Column(db.Integer, nullable=False) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -92,19 +94,7 @@ def api_base_url(self): def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - - @property - def is_agent(self) -> bool: - app_model_config = self.app_model_config - if not app_model_config: - return False - if not app_model_config.agent_mode: - return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: - return True - return False - + @property def deleted_tools(self) -> list: # get agent mode tools @@ -153,11 +143,6 @@ def deleted_tools(self) -> list: return deleted_tools -class ChatbotAppEngine(Enum): - NORMAL = 'normal' - WORKFLOW = 'workflow' - - class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( @@ -167,9 +152,9 @@ class AppModelConfig(db.Model): id = db.Column(UUID, 
server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) - configs = db.Column(db.JSON, nullable=False) + provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) + configs = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) opening_statement = db.Column(db.Text) @@ -191,7 +176,6 @@ class AppModelConfig(db.Model): dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) - chatbot_app_engine = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) workflow_id = db.Column(UUID) @property @@ -301,9 +285,6 @@ def workflow(self): def to_dict(self) -> dict: return { - "provider": "", - "model_id": "", - "configs": {}, "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, @@ -327,9 +308,6 @@ def to_dict(self) -> dict: } def from_model_config_dict(self, model_config: dict): - self.provider = "" - self.model_id = "" - self.configs = {} self.opening_statement = model_config['opening_statement'] self.suggested_questions = json.dumps(model_config['suggested_questions']) self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) @@ -358,15 +336,13 @@ def from_model_config_dict(self, model_config: dict): if model_config.get('dataset_configs') else None self.file_upload = json.dumps(model_config.get('file_upload')) \ if model_config.get('file_upload') else None + self.workflow_id = model_config.get('workflow_id') return self def copy(self): new_app_model_config = AppModelConfig( id=self.id, app_id=self.app_id, - provider="", - model_id="", - configs={}, opening_statement=self.opening_statement, suggested_questions=self.suggested_questions, suggested_questions_after_answer=self.suggested_questions_after_answer, @@ -385,7 +361,8 @@ def copy(self): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, + workflow_id=self.workflow_id ) return new_app_model_config @@ -446,12 +423,6 @@ def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - @property - def is_agent(self) -> bool: - app = self.app - if not app: - return False - return app.is_agent class Conversation(db.Model): __tablename__ = 'conversations' diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c6f0bed008986b..ed24762dd89d74 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig, ChatbotAppEngine, Site +from models.model import App, AppMode, AppModelConfig, Site from models.workflow import Workflow, WorkflowType @@ -85,8 +85,6 @@ def convert_to_workflow(self, app_model: App, account: Account) 
-> App: new_app_model_config.chat_prompt_config = '' new_app_model_config.completion_prompt_config = '' new_app_model_config.dataset_configs = '' - new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if app_model.mode == AppMode.CHAT.value else ChatbotAppEngine.NORMAL.value new_app_model_config.workflow_id = workflow.id db.session.add(new_app_model_config) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4f7262b7d6d993..3143818d12644b 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,9 +1,10 @@ import json from datetime import datetime +from typing import Optional from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter @@ -58,6 +59,40 @@ def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> # return draft workflow return workflow + def publish_draft_workflow(self, app_model: App, + account: Account, + draft_workflow: Optional[Workflow] = None) -> Workflow: + """ + Publish draft workflow + + :param app_model: App instance + :param account: Account instance + :param draft_workflow: Workflow instance + """ + if not draft_workflow: + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('No valid workflow found.') + + # create new workflow + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=draft_workflow.type, + version=str(datetime.utcnow()), + graph=draft_workflow.graph, + created_by=account.id + ) + + # commit db session changes + db.session.add(workflow) + db.session.commit() + + # return new workflow + return workflow + def get_default_block_configs(self) -> dict: """ Get default block configs @@ -77,11 +112,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode == AppMode.CHAT.value: - # check if chatbot app is in basic mode - if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: - raise ValueError('Chatbot app already in workflow mode') - elif app_model.mode != AppMode.COMPLETION.value: + if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') # convert to workflow From 9f42892b42cdb88e7a1c71383f6d482891ec98b1 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:23:20 +0800 Subject: [PATCH 024/160] lint fix --- api/constants/languages.py | 2 -- .../versions/cc04d0998d4d_set_model_config_column_nullable.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/api/constants/languages.py b/api/constants/languages.py index 0147dd8d705104..dd8a29eaef3944 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,6 +1,4 @@ -import json -from models.model import AppModelConfig languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN'] diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py 
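An aside on the publishing model added to workflow_service.py above: publish_draft_workflow never mutates the draft row; it copies the draft's graph into a brand-new Workflow record whose version is a UTC timestamp, so every publish becomes an immutable revision. A minimal sketch of that copy-on-publish idea, with a frozen dataclass standing in for the SQLAlchemy model:

from dataclasses import dataclass
from datetime import datetime

@dataclass(frozen=True)
class WorkflowSnapshot:
    app_id: str
    version: str  # str(datetime.utcnow()), matching the diff
    graph: str

def publish(app_id: str, draft_graph: str, history: list) -> WorkflowSnapshot:
    # each publish appends a new immutable row; the draft itself is untouched
    snapshot = WorkflowSnapshot(app_id, str(datetime.utcnow()), draft_graph)
    history.append(snapshot)
    return snapshot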
index c302e8b5302e5c..aefbe43f148f26 100644 --- a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py +++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py @@ -5,8 +5,8 @@ Create Date: 2024-02-27 03:47:47.376325 """ -from alembic import op import sqlalchemy as sa +from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. From c13e8077ba6bd364cb7058b02ed4cac3fa692e95 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:27:46 +0800 Subject: [PATCH 025/160] fix agent app converter command --- api/commands.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/commands.py b/api/commands.py index e376d222c65329..73325620ee8b69 100644 --- a/api/commands.py +++ b/api/commands.py @@ -405,12 +405,12 @@ def convert_to_agent_apps(): click.echo('Converting app: {}'.format(app.id)) try: - app.mode = AppMode.AGENT.value + app.mode = AppMode.AGENT_CHAT.value db.session.commit() # update conversation mode to agent db.session.query(Conversation).filter(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT.value} + {Conversation.mode: AppMode.AGENT_CHAT.value} ) db.session.commit() From 84c3ec0ea71bdcef0bde6901ddf4b6e3a64f2f56 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:40:18 +0800 Subject: [PATCH 026/160] site init move to event handler --- api/controllers/console/app/app.py | 182 +++++------------- api/events/event_handlers/__init__.py | 1 + .../create_site_record_when_app_created.py | 20 ++ api/services/workflow/workflow_converter.py | 13 +- 4 files changed, 71 insertions(+), 145 deletions(-) create mode 100644 api/events/event_handlers/create_site_record_when_app_created.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4c218bef1bc2f1..4d88733d5fb141 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,5 +1,4 @@ import json -import logging from datetime import datetime from typing import cast @@ -8,29 +7,24 @@ from flask_restful import Resource, abort, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden -from constants.languages import languages from constants.model_template import default_app_templates from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, app_pagination_fields, - template_list_fields, ) from libs.login import login_required -from models.model import App, AppModelConfig, Site, AppMode -from services.app_model_config_service import AppModelConfigService +from models.model import App, AppModelConfig, AppMode from services.workflow_service import WorkflowService from 
core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager @@ -102,95 +96,47 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - # TODO: MOVE TO IMPORT API - if args['model_config'] is not None: - # validate config - model_config_dict = args['model_config'] - - # Get provider configurations - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id) - - # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=ModelType.LLM, - only_active=True - ) - - # check if model is available - available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models] - provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}" - if provider_model not in available_models_names: - if not default_model_entity: - raise ProviderNotInitializeError( - "No Default System Reasoning Model available. Please configure " - "in the Settings -> Model Provider.") - else: - model_config_dict["model"]["provider"] = default_model_entity.provider.provider - model_config_dict["model"]["name"] = default_model_entity.model - - model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - account=current_user, - config=model_config_dict, - app_mode=args['mode'] - ) - - app = App( - enable_site=True, - enable_api=True, - is_demo=False, - api_rpm=0, - api_rph=0, - status='normal' - ) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_configuration) - else: - if 'mode' not in args or args['mode'] is None: - abort(400, message="mode is required") - - app_mode = AppMode.value_of(args['mode']) - - app_template = default_app_templates[app_mode] - - # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: - # get model provider - model_manager = ModelManager() - - # get default model instance - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - if model_instance.model == default_model_config['model']['name']: - default_model_dict = default_model_config['model'] - else: - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} - } - else: + if 'mode' not in args or args['mode'] is None: + abort(400, message="mode is required") + + app_mode = AppMode.value_of(args['mode']) + + app_template = default_app_templates[app_mode] + + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: + # get model provider + model_manager = ModelManager() + + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None + + if model_instance: + if model_instance.model == default_model_config['model']['name']: 
default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - default_model_config['model'] = json.dumps(default_model_dict) + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] - app = App(**app_template['app']) - app_model_config = AppModelConfig(**default_model_config) + default_model_config['model'] = json.dumps(default_model_dict) + app = App(**app_template['app']) app.name = args['name'] app.mode = args['mode'] app.icon = args['icon'] @@ -200,26 +146,14 @@ def post(self): db.session.add(app) db.session.flush() + app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id db.session.add(app_model_config) db.session.flush() app.app_model_config_id = app_model_config.id - account = current_user - - site = Site( - app_id=app.id, - title=app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(app) + app_was_created.send(app, account=current_user) return app, 201 @@ -262,21 +196,16 @@ def post(self): "when mode is advanced-chat or workflow") app = App( + tenant_id=current_user.current_tenant_id, + mode=app_data.get('mode'), + name=args.get("name") if args.get("name") else app_data.get('name'), + icon=args.get("icon") if args.get("icon") else app_data.get('icon'), + icon_background=args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background'), enable_site=True, - enable_api=True, - is_demo=False, - api_rpm=0, - api_rph=0, - status='normal' + enable_api=True ) - app.tenant_id = current_user.current_tenant_id - app.mode = app_data.get('mode') - app.name = args.get("name") if args.get("name") else app_data.get('name') - app.icon = args.get("icon") if args.get("icon") else app_data.get('icon') - app.icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') - db.session.add(app) db.session.commit() @@ -295,20 +224,7 @@ def post(self): app.app_model_config_id = app_model_config.id - account = current_user - - site = Site( - app_id=app.id, - title=app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(app) + app_was_created.send(app, account=current_user) return app, 201 diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 88d226d3033ab1..fdfb401bd4d334 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -2,6 +2,7 @@ from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle +from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle from .generate_conversation_name_when_first_message_created import handle diff --git a/api/events/event_handlers/create_site_record_when_app_created.py 
b/api/events/event_handlers/create_site_record_when_app_created.py new file mode 100644 index 00000000000000..25fba591d02af8 --- /dev/null +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -0,0 +1,20 @@ +from events.app_event import app_was_created +from extensions.ext_database import db +from models.model import Site + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create site record when an app is created.""" + app = sender + account = kwargs.get('account') + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index ed24762dd89d74..72c6d3f719fde7 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -93,18 +93,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: new_app.app_model_config_id = new_app_model_config.id db.session.commit() - site = Site( - app_id=new_app.id, - title=new_app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(new_app) + app_was_created.send(new_app, account=account) return new_app From 8b529a3ec7f912ac4a50c6b2463efda7c8363763 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:25:39 +0800 Subject: [PATCH 027/160] refactor app api --- api/controllers/console/app/app.py | 210 ++----------- .../console/explore/recommended_app.py | 28 +- api/services/app_service.py | 281 ++++++++++++++++++ api/services/workflow/workflow_converter.py | 2 +- 4 files changed, 309 insertions(+), 212 deletions(-) create mode 100644 api/services/app_service.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4d88733d5fb141..6c0d0ca9a6ae87 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,29 +1,18 @@ -import json -from datetime import datetime -from typing import cast - -import yaml from flask_login import current_user from flask_restful import Resource, abort, inputs, marshal_with, reqparse -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Forbidden, BadRequest -from constants.model_template import default_app_templates from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import ProviderTokenNotInitError -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted -from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, app_pagination_fields, ) from libs.login import login_required +from services.app_service import AppService from models.model import App, AppModelConfig, AppMode from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager @@ -49,32 +38,9 @@ def 
get(self): parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() - filters = [ - App.tenant_id == current_user.current_tenant_id, - App.is_universal == False - ] - - if args['mode'] == 'workflow': - filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': - filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent': - filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': - filters.append(App.mode == AppMode.CHANNEL.value) - else: - pass - - if 'name' in args and args['name']: - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - - app_models = db.paginate( - db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False - ) + # get app list + app_service = AppService() + app_models = app_service.get_paginate_apps(current_user.current_tenant_id, args) return app_models @@ -97,63 +63,10 @@ def post(self): raise Forbidden() if 'mode' not in args or args['mode'] is None: - abort(400, message="mode is required") - - app_mode = AppMode.value_of(args['mode']) - - app_template = default_app_templates[app_mode] - - # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: - # get model provider - model_manager = ModelManager() - - # get default model instance - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - if model_instance.model == default_model_config['model']['name']: - default_model_dict = default_model_config['model'] - else: - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} - } - else: - default_model_dict = default_model_config['model'] - - default_model_config['model'] = json.dumps(default_model_dict) - - app = App(**app_template['app']) - app.name = args['name'] - app.mode = args['mode'] - app.icon = args['icon'] - app.icon_background = args['icon_background'] - app.tenant_id = current_user.current_tenant_id - - db.session.add(app) - db.session.flush() - - app_model_config = AppModelConfig(**default_model_config) - app_model_config.app_id = app.id - db.session.add(app_model_config) - db.session.flush() - - app.app_model_config_id = app_model_config.id + raise BadRequest("mode is required") - app_was_created.send(app, account=current_user) + app_service = AppService() + app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -177,54 +90,8 @@ def post(self): parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - try: - import_data = yaml.safe_load(args['data']) - except yaml.YAMLError as e: - raise ValueError("Invalid YAML format in data argument.") - - app_data = import_data.get('app') - model_config_data = import_data.get('model_config') - workflow_graph = import_data.get('workflow_graph') - - if not app_data or not model_config_data: - raise ValueError("Missing app or model_config in data argument") - - 
app_mode = AppMode.value_of(app_data.get('mode')) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - if not workflow_graph: - raise ValueError("Missing workflow_graph in data argument " - "when mode is advanced-chat or workflow") - - app = App( - tenant_id=current_user.current_tenant_id, - mode=app_data.get('mode'), - name=args.get("name") if args.get("name") else app_data.get('name'), - icon=args.get("icon") if args.get("icon") else app_data.get('icon'), - icon_background=args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background'), - enable_site=True, - enable_api=True - ) - - db.session.add(app) - db.session.commit() - - if workflow_graph: - workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, current_user) - published_workflow = workflow_service.publish_draft_workflow(app, current_user, draft_workflow) - model_config_data['workflow_id'] = published_workflow.id - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_config_data) - app_model_config.app_id = app.id - - db.session.add(app_model_config) - db.session.commit() - - app.app_model_config_id = app_model_config.id - - app_was_created.send(app, account=current_user) + app_service = AppService() + app = app_service.import_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -286,13 +153,8 @@ def delete(self, app_model): if not current_user.is_admin_or_owner: raise Forbidden() - db.session.delete(app_model) - db.session.commit() - - # todo delete related data?? - # model_config, site, api_token, conversation, message, message_feedback, message_annotation - - app_was_deleted.send(app_model) + app_service = AppService() + app_service.delete_app(app_model) return {'result': 'success'}, 204 @@ -304,28 +166,10 @@ class AppExportApi(Resource): @get_app_model def get(self, app_model): """Export app""" - app_model_config = app_model.app_model_config - - export_data = { - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background - }, - "model_config": app_model_config.to_dict(), - } - - if app_model_config.workflow_id: - export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) - else: - # get draft workflow - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - export_data['workflow_graph'] = json.loads(workflow.graph) + app_service = AppService() return { - "data": yaml.dump(export_data) + "data": app_service.export_app(app_model) } @@ -340,9 +184,9 @@ def post(self, app_model): parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app_model.name = args.get('name') - app_model.updated_at = datetime.utcnow() - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_name(app_model, args.get('name')) + return app_model @@ -358,10 +202,8 @@ def post(self, app_model): parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app_model.icon = args.get('icon') - app_model.icon_background = args.get('icon_background') - app_model.updated_at = datetime.utcnow() - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) return app_model @@ -377,12 +219,9 @@ def post(self, app_model): parser.add_argument('enable_site', type=bool, 
required=True, location='json') args = parser.parse_args() - if args.get('enable_site') == app_model.enable_site: - return app_model + app_service = AppService() + app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) - app_model.enable_site = args.get('enable_site') - app_model.updated_at = datetime.utcnow() - db.session.commit() return app_model @@ -397,12 +236,9 @@ def post(self, app_model): parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - if args.get('enable_api') == app_model.enable_api: - return app_model + app_service = AppService() + app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) - app_model.enable_api = args.get('enable_api') - app_model.updated_at = datetime.utcnow() - db.session.commit() return app_model diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 3c28980f5188de..8190f7828dc755 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,6 +1,3 @@ -import json - -import yaml from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse @@ -9,7 +6,7 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from models.model import App, RecommendedApp -from services.workflow_service import WorkflowService +from services.app_service import AppService app_fields = { 'id': fields.String, @@ -103,25 +100,8 @@ def get(self, app_id): if not app_model or not app_model.is_public: raise AppNotFoundError - app_model_config = app_model.app_model_config - - export_data = { - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background - }, - "model_config": app_model_config.to_dict(), - } - - if app_model_config.workflow_id: - export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) - else: - # get draft workflow - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - export_data['workflow_graph'] = json.loads(workflow.graph) + app_service = AppService() + export_str = app_service.export_app(app_model) return { 'id': app_model.id, @@ -129,7 +109,7 @@ def get(self, app_id): 'icon': app_model.icon, 'icon_background': app_model.icon_background, 'mode': app_model.mode, - 'export_data': yaml.dump(export_data) + 'export_data': export_str } diff --git a/api/services/app_service.py b/api/services/app_service.py new file mode 100644 index 00000000000000..e80c720d4ccd26 --- /dev/null +++ b/api/services/app_service.py @@ -0,0 +1,281 @@ +import json +from datetime import datetime +from typing import cast + +import yaml + +from constants.model_template import default_app_templates +from core.errors.error import ProviderTokenNotInitError +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from events.app_event import app_was_created, app_was_deleted +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, AppModelConfig +from services.workflow_service import WorkflowService + + +class AppService: + def get_paginate_apps(self, tenant_id: str, args: dict) -> list[App]: + """ + Get app list with 
pagination + :param tenant_id: tenant id + :param args: request args + :return: + """ + filters = [ + App.tenant_id == tenant_id, + App.is_universal == False + ] + + if args['mode'] == 'workflow': + filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) + elif args['mode'] == 'chat': + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args['mode'] == 'agent': + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args['mode'] == 'channel': + filters.append(App.mode == AppMode.CHANNEL.value) + + if 'name' in args and args['name']: + name = args['name'][:30] + filters.append(App.name.ilike(f'%{name}%')) + + app_models = db.paginate( + db.select(App).where(*filters).order_by(App.created_at.desc()), + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return app_models + + def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Create app + :param tenant_id: tenant id + :param args: request args + :param account: Account instance + """ + app_mode = AppMode.value_of(args['mode']) + app_template = default_app_templates[app_mode] + + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: + # get model provider + model_manager = ModelManager() + + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=account.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None + + if model_instance: + if model_instance.model == default_model_config['model']['name']: + default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] + + default_model_config['model'] = json.dumps(default_model_dict) + + app = App(**app_template['app']) + app.name = args['name'] + app.mode = args['mode'] + app.icon = args['icon'] + app.icon_background = args['icon_background'] + app.tenant_id = account.current_tenant_id + + db.session.add(app) + db.session.flush() + + app_model_config = AppModelConfig(**default_model_config) + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() + + app.app_model_config_id = app_model_config.id + + app_was_created.send(app, account=account) + + return app + + def import_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Import app + :param tenant_id: tenant id + :param args: request args + :param account: Account instance + """ + try: + import_data = yaml.safe_load(args['data']) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML format in data argument.") + + app_data = import_data.get('app') + model_config_data = import_data.get('model_config') + workflow_graph = import_data.get('workflow_graph') + + if not app_data or not model_config_data: + raise ValueError("Missing app or model_config in data argument") + + app_mode = AppMode.value_of(app_data.get('mode')) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if not workflow_graph: + raise ValueError("Missing workflow_graph in data argument " + "when mode is advanced-chat or 
workflow") + + app = App( + tenant_id=tenant_id, + mode=app_data.get('mode'), + name=args.get("name") if args.get("name") else app_data.get('name'), + icon=args.get("icon") if args.get("icon") else app_data.get('icon'), + icon_background=args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background'), + enable_site=True, + enable_api=True + ) + + db.session.add(app) + db.session.commit() + + if workflow_graph: + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, account) + published_workflow = workflow_service.publish_draft_workflow(app, account, draft_workflow) + model_config_data['workflow_id'] = published_workflow.id + + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + + app_was_created.send(app, account=account) + + return app + + def export_app(self, app: App) -> str: + """ + Export app + :param app: App instance + :return: + """ + app_model_config = app.app_model_config + + export_data = { + "app": { + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background + }, + "model_config": app_model_config.to_dict(), + } + + if app_model_config.workflow_id: + export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + else: + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app) + export_data['workflow_graph'] = json.loads(workflow.graph) + + return yaml.dump(export_data) + + def update_app_name(self, app: App, name: str) -> App: + """ + Update app name + :param app: App instance + :param name: new name + :return: App instance + """ + app.name = name + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + """ + Update app icon + :param app: App instance + :param icon: new icon + :param icon_background: new icon_background + :return: App instance + """ + app.icon = icon + app.icon_background = icon_background + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_site_status(self, app: App, enable_site: bool) -> App: + """ + Update app site status + :param app: App instance + :param enable_site: enable site status + :return: App instance + """ + if enable_site == app.enable_site: + return app + + app.enable_site = enable_site + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_api_status(self, app: App, enable_api: bool) -> App: + """ + Update app api status + :param app: App instance + :param enable_api: enable api status + :return: App instance + """ + if enable_api == app.enable_api: + return app + + app.enable_api = enable_api + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def delete_app(self, app: App) -> None: + """ + Delete app + :param app: App instance + """ + db.session.delete(app) + db.session.commit() + + app_was_deleted.send(app) + + # todo async delete related data by event + # app_model_configs, site, api_tokens, installed_apps, recommended_apps BY app + # app_annotation_hit_histories, app_annotation_settings, app_dataset_joins BY app + # workflows, workflow_runs, workflow_node_executions, workflow_app_logs BY app + # conversations, pinned_conversations, messages BY app 
+ # message_feedbacks, message_annotations, message_chains BY message + # message_agent_thoughts, message_files, saved_messages BY message + + diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 72c6d3f719fde7..fb6cf1fd5a4fb5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow, WorkflowType From 4f50f113dd192e6dfd5d4164bf7fc0e0a26962fb Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:25:49 +0800 Subject: [PATCH 028/160] lint fix --- api/services/app_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index e80c720d4ccd26..f3a12a8b9cc0ce 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -7,7 +7,7 @@ from constants.model_template import default_app_templates from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db From a457faa2bf488ca7ad8dcee4b2a4103c0f3da506 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:28:40 +0800 Subject: [PATCH 029/160] trigger app_model_config_was_updated when app import --- api/services/app_service.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index f3a12a8b9cc0ce..375c1021147004 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -9,7 +9,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted +from events.app_event import app_was_created, app_was_deleted, app_model_config_was_updated from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, AppModelConfig @@ -171,6 +171,11 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App: app_was_created.send(app, account=account) + app_model_config_was_updated.send( + app, + app_model_config=app_model_config + ) + return app def export_app(self, app: App) -> str: From 742b87df5e3be98cfa47ecff3b7ae160f0f060ff Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:29:17 +0800 Subject: [PATCH 030/160] lint fix --- api/services/app_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 375c1021147004..a83c7e6ac41025 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -9,7 +9,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from 
core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted, app_model_config_was_updated +from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, AppModelConfig From 7d51d6030be5896bb3f4299cff4387d6b50255d4 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:36:42 +0800 Subject: [PATCH 031/160] remove publish workflow when app import --- api/services/app_service.py | 7 ++----- api/services/workflow_service.py | 34 ++++++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index a83c7e6ac41025..6955a6dccbd344 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -155,10 +155,9 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App: db.session.commit() if workflow_graph: + # init draft workflow workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, account) - published_workflow = workflow_service.publish_draft_workflow(app, account, draft_workflow) - model_config_data['workflow_id'] = published_workflow.id + workflow_service.sync_draft_workflow(app, workflow_graph, account) app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(model_config_data) @@ -282,5 +281,3 @@ def delete_app(self, app: App) -> None: # conversations, pinned_conversations, messages BY app # message_feedbacks, message_annotations, message_chains BY message # message_agent_thoughts, message_files, saved_messages BY message - - diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3143818d12644b..dac88d6396a2e2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -59,11 +59,11 @@ def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> # return draft workflow return workflow - def publish_draft_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, + account: Account, + draft_workflow: Optional[Workflow] = None) -> Workflow: """ - Publish draft workflow + Publish workflow from draft :param app_model: App instance :param account: Account instance @@ -76,6 +76,8 @@ def publish_draft_workflow(self, app_model: App, if not draft_workflow: raise ValueError('No valid workflow found.') + # TODO check if the workflow is valid + # create new workflow workflow = Workflow( tenant_id=app_model.tenant_id, @@ -90,6 +92,30 @@ def publish_draft_workflow(self, app_model: App, db.session.add(workflow) db.session.commit() + app_model_config = app_model.app_model_config + + # create new app model config record + new_app_model_config = app_model_config.copy() + new_app_model_config.id = None + new_app_model_config.app_id = app_model.id + new_app_model_config.external_data_tools = '' + new_app_model_config.model = '' + new_app_model_config.user_input_form = '' + new_app_model_config.dataset_query_variable = None + new_app_model_config.pre_prompt = None + new_app_model_config.agent_mode = '' + new_app_model_config.prompt_type = 'simple' + new_app_model_config.chat_prompt_config = '' + new_app_model_config.completion_prompt_config = '' + new_app_model_config.dataset_configs = 
'' + new_app_model_config.workflow_id = workflow.id + + db.session.add(new_app_model_config) + db.session.flush() + + app_model.app_model_config_id = new_app_model_config.id + db.session.commit() + # return new workflow return workflow From 03749917f04be9ef5473ca3e72f84e62cab24c98 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 18:03:47 +0800 Subject: [PATCH 032/160] add workflow app log api --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/app.py | 4 +- api/controllers/console/app/workflow.py | 36 +++++++++++ .../console/app/workflow_app_log.py | 41 ++++++++++++ api/fields/end_user_fields.py | 8 +++ api/fields/workflow_app_log_fields.py | 25 ++++++++ api/fields/workflow_fields.py | 13 ++++ api/models/__init__.py | 45 +++++++++++++- api/models/workflow.py | 20 +++++- api/services/app_service.py | 3 +- api/services/workflow_app_service.py | 62 +++++++++++++++++++ api/services/workflow_service.py | 24 ++++++- 12 files changed, 276 insertions(+), 7 deletions(-) create mode 100644 api/controllers/console/app/workflow_app_log.py create mode 100644 api/fields/end_user_fields.py create mode 100644 api/fields/workflow_app_log_fields.py create mode 100644 api/services/workflow_app_service.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 649df278ecd02c..a6f803785ab2f4 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ from . import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic, workflow) + model_config, site, statistic, workflow, workflow_app_log) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6c0d0ca9a6ae87..898fd4f7c40a75 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -40,9 +40,9 @@ def get(self): # get app list app_service = AppService() - app_models = app_service.get_paginate_apps(current_user.current_tenant_id, args) + app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) - return app_models + return app_pagination @setup_required @login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6023d0ba45a9d4..8e51ae8cbd79c8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -51,6 +51,41 @@ def post(self, app_model: App): } +class PublishedWorkflowApi(Resource): + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get published workflow + """ + # fetch published workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=app_model) + + # return workflow, if not found, return None + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Publish workflow + """ + workflow_service = WorkflowService() + workflow_service.publish_workflow(app_model=app_model, account=current_user) + + return { + 
"result": "success" + } + + + class DefaultBlockConfigApi(Resource): @setup_required @login_required @@ -88,5 +123,6 @@ def post(self, app_model: App): api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py new file mode 100644 index 00000000000000..87614d549d1d93 --- /dev/null +++ b/api/controllers/console/app/workflow_app_log.py @@ -0,0 +1,41 @@ +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from libs.login import login_required +from models.model import AppMode, App +from services.workflow_app_service import WorkflowAppService + + +class WorkflowAppLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_app_log_pagination_fields) + def get(self, app_model: App): + """ + Get workflow app logs + """ + parser = reqparse.RequestParser() + parser.add_argument('keyword', type=str, location='args') + parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') + parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') + parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + args = parser.parse_args() + + # get paginate workflow app logs + workflow_app_service = WorkflowAppService() + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( + app_model=app_model, + args=args + ) + + return workflow_app_log_pagination + + +api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py new file mode 100644 index 00000000000000..ee630c12c2e9aa --- /dev/null +++ b/api/fields/end_user_fields.py @@ -0,0 +1,8 @@ +from flask_restful import fields + +simple_end_user_fields = { + 'id': fields.String, + 'type': fields.String, + 'is_anonymous': fields.Boolean, + 'session_id': fields.String, +} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py new file mode 100644 index 00000000000000..6862f0411d9f1f --- /dev/null +++ b/api/fields/workflow_app_log_fields.py @@ -0,0 +1,25 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from fields.workflow_fields import workflow_run_fields +from libs.helper import TimestampField + + +workflow_app_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True), + "created_from": fields.String, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": 
TimestampField +} + +workflow_app_log_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index decdc0567f1934..091f2931507fba 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,3 +13,16 @@ 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_at': TimestampField } + +workflow_run_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "finished_at": TimestampField +} \ No newline at end of file diff --git a/api/models/__init__.py b/api/models/__init__.py index 44d37d3052e8cd..47eec535428105 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1 +1,44 @@ -# -*- coding:utf-8 -*- \ No newline at end of file +from enum import Enum + + +class CreatedByRole(Enum): + """ + Enum class for createdByRole + """ + ACCOUNT = "account" + END_USER = "end_user" + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for role in cls: + if role.value == value: + return role + raise ValueError(f'invalid createdByRole value {value}') + + +class CreatedFrom(Enum): + """ + Enum class for createdFrom + """ + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + + @classmethod + def value_of(cls, value: str) -> 'CreatedFrom': + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for role in cls: + if role.value == value: + return role + raise ValueError(f'invalid createdFrom value {value}') diff --git a/api/models/workflow.py b/api/models/workflow.py index 251f33b0c08d8c..41266fe9f567b0 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,6 +5,7 @@ from extensions.ext_database import db from models.account import Account +from models.model import EndUser class CreatedByRole(Enum): @@ -148,6 +149,7 @@ class WorkflowRunStatus(Enum): RUNNING = 'running' SUCCEEDED = 'succeeded' FAILED = 'failed' + STOPPED = 'stopped' @classmethod def value_of(cls, value: str) -> 'WorkflowRunStatus': @@ -184,7 +186,7 @@ class WorkflowRun(db.Model): - version (string) Version - graph (text) Workflow canvas configuration (JSON) - inputs (text) Input parameters - - status (string) Execution status, `running` / `succeeded` / `failed` + - status (string) Execution status, `running` / `succeeded` / `failed` / `stopped` - outputs (text) `optional` Output content - error (string) `optional` Error reason - elapsed_time (float) `optional` Time consumption (s) @@ -366,3 +368,19 @@ class WorkflowAppLog(db.Model): created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def workflow_run(self): + return WorkflowRun.query.get(self.workflow_run_id) + + @property + def created_by_account(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return EndUser.query.get(self.created_by) \ + if created_by_role == CreatedByRole.END_USER else None diff --git a/api/services/app_service.py b/api/services/app_service.py index 6955a6dccbd344..5de87dbad5dda8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -3,6 +3,7 @@ from typing import cast import yaml +from flask_sqlalchemy.pagination import Pagination from constants.model_template import default_app_templates from core.errors.error import ProviderTokenNotInitError @@ -17,7 +18,7 @@ class AppService: - def get_paginate_apps(self, tenant_id: str, args: dict) -> list[App]: + def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination: """ Get app list with pagination :param tenant_id: tenant id diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py new file mode 100644 index 00000000000000..5897fcf1823533 --- /dev/null +++ b/api/services/workflow_app_service.py @@ -0,0 +1,62 @@ +from flask_sqlalchemy.pagination import Pagination +from sqlalchemy import or_, and_ + +from extensions.ext_database import db +from models import CreatedByRole +from models.model import App, EndUser +from models.workflow import WorkflowAppLog, WorkflowRunStatus, WorkflowRun + + +class WorkflowAppService: + + def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: + """ + Get paginate workflow app logs + :param app: app model + :param args: request args + :return: + """ + query = ( + db.select(WorkflowAppLog) + .where( + WorkflowAppLog.tenant_id == app_model.tenant_id, + WorkflowAppLog.app_id == app_model.id + ) + ) + + status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None + if args['keyword'] or 
status:
+            query = query.join(
+                WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
+            )
+
+        if args['keyword']:
+            keyword_val = f"%{args['keyword'][:30]}%"
+            keyword_conditions = [
+                WorkflowRun.inputs.ilike(keyword_val),
+                WorkflowRun.outputs.ilike(keyword_val),
+                # filter keyword by end user session id if created by end user role
+                and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val))
+            ]
+
+            query = query.outerjoin(
+                EndUser,
+                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
+            ).filter(or_(*keyword_conditions))
+
+        if status:
+            # join with workflow_run and filter by status
+            query = query.filter(
+                WorkflowRun.status == status.value
+            )
+
+        query = query.order_by(WorkflowAppLog.created_at.desc())
+
+        pagination = db.paginate(
+            query,
+            page=args['page'],
+            per_page=args['limit'],
+            error_out=False
+        )
+
+        return pagination
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index dac88d6396a2e2..ae6e4c46d3a750 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -15,7 +15,7 @@ class WorkflowService:
     Workflow Service
     """
 
-    def get_draft_workflow(self, app_model: App) -> Workflow:
+    def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
         """
         Get draft workflow
         """
@@ -29,6 +29,26 @@ def get_draft_workflow(self, app_model: App) -> Workflow:
         # return draft workflow
         return workflow
 
+    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
+        """
+        Get published workflow
+        """
+        app_model_config = app_model.app_model_config
+
+        if not app_model_config.workflow_id:
+            return None
+
+        # fetch published workflow by workflow_id
+        workflow = db.session.query(Workflow).filter(
+            Workflow.tenant_id == app_model.tenant_id,
+            Workflow.app_id == app_model.id,
+            Workflow.id == app_model_config.workflow_id
+        ).first()
+
+        # return published workflow
+        return workflow
+
     def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow:
         """
         Sync draft workflow
@@ -116,6 +136,8 @@ def publish_workflow(self, app_model: App,
         app_model.app_model_config_id = new_app_model_config.id
         db.session.commit()
 
+        # TODO update app related datasets
+
         # return new workflow
         return workflow
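Note: the workflow app log API added in this patch can be exercised over HTTP once deployed. A request sketch, assuming the console API prefix `/console/api` and placeholder `app_id` / `token` values (neither is part of the patch):

    import requests

    resp = requests.get(
        f"https://dify.example.com/console/api/apps/{app_id}/workflow-app-logs",
        params={"status": "failed", "keyword": "timeout", "page": 1, "limit": 20},
        headers={"Authorization": f"Bearer {token}"},
    )
    payload = resp.json()
    # pagination envelope mirrors workflow_app_log_pagination_fields:
    # {"page": ..., "limit": ..., "total": ..., "has_more": ..., "data": [...]}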
From bf4a5f6b33f8516bc0392e3c2d07284393c2914f Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 27 Feb 2024 18:04:01 +0800
Subject: [PATCH 033/160] lint fix

---
 api/controllers/console/app/workflow_app_log.py | 2 +-
 api/fields/workflow_app_log_fields.py           | 1 -
 api/services/workflow_app_service.py            | 4 ++--
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 87614d549d1d93..6d1709ed8e65d9 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -7,7 +7,7 @@
 from controllers.console.wraps import account_initialization_required
 from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs.login import login_required
-from models.model import AppMode, App
+from models.model import App, AppMode
 from services.workflow_app_service import WorkflowAppService

diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index 6862f0411d9f1f..8f3998d90ab1da 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -5,7 +5,6 @@
 from fields.workflow_fields import workflow_run_fields
 from libs.helper import TimestampField
 
-
 workflow_app_log_partial_fields = {
     "id": fields.String,
     "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True),
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 5897fcf1823533..047678837509e2 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -1,10 +1,10 @@
 from flask_sqlalchemy.pagination import Pagination
-from sqlalchemy import or_, and_
+from sqlalchemy import and_, or_
 
 from extensions.ext_database import db
 from models import CreatedByRole
 from models.model import App, EndUser
-from models.workflow import WorkflowAppLog, WorkflowRunStatus, WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
 
 
 class WorkflowAppService:

From 20cf075b2dacc54fc3d5ee713d3b94850f0a8db2 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 27 Feb 2024 21:39:13 +0800
Subject: [PATCH 034/160] add workflow runs & workflow node executions api

---
 api/controllers/console/app/workflow.py     | 60 +++++++++++-
 api/controllers/console/app/workflow_run.py | 80 ++++++++++++++++
 api/fields/conversation_fields.py           |  1 +
 api/fields/workflow_app_log_fields.py       |  4 +-
 api/fields/workflow_fields.py               | 13 ----
 api/fields/workflow_run_fields.py           | 92 +++++++++++++++++++
 .../versions/b289e2408ee2_add_workflow.py   |  2 +-
 api/models/workflow.py                      | 45 ++++++++-
 api/services/workflow_run_service.py        | 89 ++++++++++++++++++
 9 files changed, 365 insertions(+), 21 deletions(-)
 create mode 100644 api/controllers/console/app/workflow_run.py
 create mode 100644 api/fields/workflow_run_fields.py
 create mode 100644 api/services/workflow_run_service.py

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 8e51ae8cbd79c8..4fcf8daf6ec64e 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -51,6 +51,62 @@ def post(self, app_model: App):
         }
 
 
+class DraftWorkflowRunApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def post(self, app_model: App):
+        """
+        Run draft workflow
+        """
+        # TODO
+        workflow_service = WorkflowService()
+        workflow_service.run_draft_workflow(app_model=app_model, account=current_user)
+
+        # TODO
+        return {
+            "result": "success"
+        }
+
+
+class WorkflowTaskStopApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def post(self, app_model: App, task_id: str):
+        """
+        Stop workflow task
+        """
+        # TODO
+        workflow_service = WorkflowService()
+        workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user)
+
+        return {
+            "result": "success"
+        }
+
+
+class DraftWorkflowNodeRunApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def post(self, app_model: App, node_id: str):
+        """
+        Run draft workflow node
+        """
+        # TODO
+        workflow_service = WorkflowService()
+        workflow_service.run_draft_workflow_node(app_model=app_model, node_id=node_id, account=current_user)
+
+        # TODO
+        return {
+            "result": "success"
+        }
+
+
 class PublishedWorkflowApi(Resource):
 
     @setup_required
@@ -85,7 +141,6 @@ def post(self, app_model: App):
         }
 
 
-
 class DefaultBlockConfigApi(Resource):
     @setup_required
     @login_required
@@ -88,5 +123,6 @@ 
def post(self, app_model: App):
 
 
 api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
+api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
+api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop')
+api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
 api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published')
 api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
 api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
new file mode 100644
index 00000000000000..38e3d4d837dde9
--- /dev/null
+++ b/api/controllers/console/app/workflow_run.py
@@ -0,0 +1,80 @@
+from flask_restful import Resource, marshal_with, reqparse
+from flask_restful.inputs import int_range
+
+from controllers.console import api
+from controllers.console.app.wraps import get_app_model
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from fields.workflow_run_fields import workflow_run_detail_fields, workflow_run_pagination_fields, \
+    workflow_run_node_execution_list_fields
+from libs.helper import uuid_value
+from libs.login import login_required
+from models.model import App, AppMode
+from services.workflow_run_service import WorkflowRunService
+
+
+class WorkflowRunListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @marshal_with(workflow_run_pagination_fields)
+    def get(self, app_model: App):
+        """
+        Get workflow run list
+        """
+        parser = reqparse.RequestParser()
+        parser.add_argument('last_id', type=uuid_value, location='args')
+        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        args = parser.parse_args()
+
+        workflow_run_service = WorkflowRunService()
+        result = workflow_run_service.get_paginate_workflow_runs(
+            app_model=app_model,
+            args=args
+        )
+
+        return result
+
+
+class WorkflowRunDetailApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @marshal_with(workflow_run_detail_fields)
+    def get(self, app_model: App, run_id):
+        """
+        Get workflow run detail
+        """
+        run_id = str(run_id)
+
+        workflow_run_service = WorkflowRunService()
+        workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
+
+        return workflow_run
+
+
+class WorkflowRunNodeExecutionListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @marshal_with(workflow_run_node_execution_list_fields)
+    def get(self, app_model: App, run_id):
+        """
+        Get workflow run node execution list
+        """
+        run_id = str(run_id)
+
+        workflow_run_service = WorkflowRunService()
+        node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
+
+        return {
+            'data': node_executions
+        }
+
+
+api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
+api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
+api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index afa486f1cdfe6c..747b0b86abf3ef 100644
--- 
a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -66,6 +66,7 @@ def format(self, value): 'from_end_user_id': fields.String, 'from_account_id': fields.String, 'feedbacks': fields.List(fields.Nested(feedback_fields)), + 'workflow_run_id': fields.String, 'annotation': fields.Nested(annotation_fields, allow_null=True), 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), 'created_at': TimestampField, diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 8f3998d90ab1da..e230c159fba59a 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -2,12 +2,12 @@ from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields -from fields.workflow_fields import workflow_run_fields +from fields.workflow_run_fields import workflow_run_for_log_fields from libs.helper import TimestampField workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), "created_from": fields.String, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 091f2931507fba..decdc0567f1934 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,16 +13,3 @@ 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_at': TimestampField } - -workflow_run_fields = { - "id": fields.String, - "version": fields.String, - "status": fields.String, - "error": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_price": fields.Float, - "currency": fields.String, - "total_steps": fields.Integer, - "finished_at": TimestampField -} \ No newline at end of file diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py new file mode 100644 index 00000000000000..37751bc70f9058 --- /dev/null +++ b/api/fields/workflow_run_fields.py @@ -0,0 +1,92 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +workflow_run_for_log_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_for_list_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.String, + "inputs": fields.String, + "status": fields.String, + "outputs": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_pagination_fields = { + 'page': fields.Integer, + 'limit': 
fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='items') +} + +workflow_run_detail_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.String, + "inputs": fields.String, + "status": fields.String, + "outputs": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_node_execution_fields = { + "id": fields.String, + "index": fields.Integer, + "predecessor_node_id": fields.String, + "node_id": fields.String, + "node_type": fields.String, + "title": fields.String, + "inputs": fields.String, + "process_data": fields.String, + "outputs": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "execution_metadata": fields.String, + "created_at": TimestampField, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "finished_at": TimestampField +} + +workflow_run_node_execution_list_fields = { + 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), +} diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 7255b4b5fa6ba3..5f7ddc7d688f28 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -88,7 +88,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') ) with op.batch_alter_table('workflow_runs', schema=None) as batch_op: - batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from'], unique=False) + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'triggered_from'], unique=False) op.create_table('workflows', sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), diff --git a/api/models/workflow.py b/api/models/workflow.py index 41266fe9f567b0..7ea342cda782b2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -208,7 +208,7 @@ class WorkflowRun(db.Model): __tablename__ = 'workflow_runs' __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from'), + db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) @@ -236,11 +236,36 @@ class WorkflowRun(db.Model): @property def created_by_account(self): - return Account.query.get(self.created_by) + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None @property - def 
updated_by_account(self):
-        return Account.query.get(self.updated_by)
+    def created_by_end_user(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return EndUser.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.END_USER else None
+
+
+class WorkflowNodeExecutionTriggeredFrom(Enum):
+    """
+    Workflow Node Execution Triggered From Enum
+    """
+    SINGLE_STEP = 'single-step'
+    WORKFLOW_RUN = 'workflow-run'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom':
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f'invalid workflow node execution triggered from value {value}')
 
 
 class WorkflowNodeExecution(db.Model):
@@ -323,6 +348,18 @@ class WorkflowNodeExecution(db.Model):
     created_by = db.Column(UUID, nullable=False)
     finished_at = db.Column(db.DateTime)
 
+    @property
+    def created_by_account(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return Account.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.ACCOUNT else None
+
+    @property
+    def created_by_end_user(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return EndUser.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.END_USER else None
+
 
 class WorkflowAppLog(db.Model):
     """
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
new file mode 100644
index 00000000000000..9c898f10fbf9aa
--- /dev/null
+++ b/api/services/workflow_run_service.py
@@ -0,0 +1,89 @@
+from extensions.ext_database import db
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.model import App
+from models.workflow import WorkflowRun, WorkflowRunTriggeredFrom, WorkflowNodeExecution, \
+    WorkflowNodeExecutionTriggeredFrom
+
+
+class WorkflowRunService:
+    def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
+        """
+        Get debug workflow run list
+        Only return triggered_from == debugging
+
+        :param app_model: app model
+        :param args: request args
+        """
+        limit = int(args.get('limit', 20))
+
+        base_query = db.session.query(WorkflowRun).filter(
+            WorkflowRun.tenant_id == app_model.tenant_id,
+            WorkflowRun.app_id == app_model.id,
+            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
+        )
+
+        if args.get('last_id'):
+            last_workflow_run = base_query.filter(
+                WorkflowRun.id == args.get('last_id'),
+            ).first()
+
+            if not last_workflow_run:
+                raise ValueError('Last workflow run does not exist')
+
+            workflow_runs = base_query.filter(
+                WorkflowRun.created_at < last_workflow_run.created_at,
+                WorkflowRun.id != last_workflow_run.id
+            ).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+        else:
+            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+
+        has_more = False
+        if len(workflow_runs) == limit:
+            current_page_last_workflow_run = workflow_runs[-1]
+            rest_count = base_query.filter(
+                WorkflowRun.created_at < current_page_last_workflow_run.created_at,
+                WorkflowRun.id != current_page_last_workflow_run.id
+            ).count()
+
+            if rest_count > 0:
+                has_more = True
+
+        return InfiniteScrollPagination(
+            data=workflow_runs,
+            limit=limit,
+            has_more=has_more
+        )
+
+    def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
+        """
+        Get workflow run detail
+
+        :param app_model: app model
+        :param run_id: workflow run id
+        """
+        workflow_run = db.session.query(WorkflowRun).filter(
+            WorkflowRun.tenant_id == app_model.tenant_id,
+            WorkflowRun.app_id == app_model.id,
+            WorkflowRun.id == run_id,
+        ).first()
+
+        return workflow_run
+
+    def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
+        """
+        Get workflow run node execution list
+        """
+        workflow_run = self.get_workflow_run(app_model, run_id)
+
+        if not workflow_run:
+            return []
+
+        node_executions = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.tenant_id == app_model.tenant_id,
+            WorkflowNodeExecution.app_id == app_model.id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == run_id,
+        ).order_by(WorkflowNodeExecution.index.desc()).all()
+
+        return node_executions
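Note: unlike the page/limit pagination used for workflow app logs, the run list above is cursor-based: the client passes the id of the last run it has seen and keeps fetching until has_more is false. A client-side sketch, assuming the same console API prefix and placeholder credentials as in the earlier example:

    import requests

    def iter_workflow_runs(base_url: str, app_id: str, token: str, limit: int = 20):
        # yields runs page by page, following the last_id cursor
        last_id = None
        while True:
            params = {"limit": limit}
            if last_id:
                params["last_id"] = last_id
            resp = requests.get(
                f"{base_url}/apps/{app_id}/workflow-runs",
                params=params,
                headers={"Authorization": f"Bearer {token}"},
            )
            payload = resp.json()  # envelope shape per workflow_run_pagination_fields
            yield from payload["data"]
            if not payload["has_more"] or not payload["data"]:
                break
            last_id = payload["data"][-1]["id"]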
From 124aa9db08f90a3fb8900dfed35ce1f018678520 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 27 Feb 2024 21:39:20 +0800
Subject: [PATCH 035/160] lint fix

---
 api/controllers/console/app/workflow_run.py | 7 +++++--
 api/services/workflow_run_service.py        | 8 ++++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 38e3d4d837dde9..8a4c0492a1551f 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -5,8 +5,11 @@
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from fields.workflow_run_fields import workflow_run_detail_fields, workflow_run_pagination_fields, \
-    workflow_run_node_execution_list_fields
+from fields.workflow_run_fields import (
+    workflow_run_detail_fields,
+    workflow_run_node_execution_list_fields,
+    workflow_run_pagination_fields,
+)
 from libs.helper import uuid_value
 from libs.login import login_required
 from models.model import App, AppMode
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 9c898f10fbf9aa..70ce1f2ce0406e 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,8 +1,12 @@
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.model import App
-from models.workflow import WorkflowRun, WorkflowRunTriggeredFrom, WorkflowNodeExecution, \
-    WorkflowNodeExecutionTriggeredFrom
+from models.workflow import (
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionTriggeredFrom,
+    WorkflowRun,
+    WorkflowRunTriggeredFrom,
+)

From 7724d010b6e4e025e60e135ec85963928fc146c1 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 16:27:41 +0800
Subject: [PATCH 036/160] add app description add update app api

---
 api/controllers/console/app/app.py      | 23 ++++++++++++-
 api/fields/app_fields.py                |  4 +++
 .../f9107f83abab_add_desc_for_apps.py   | 32 +++++++++++++++++++
 api/models/model.py                     |  4 ++-
 api/models/workflow.py                  |  4 ++-
 api/services/app_service.py             | 20 +++++++++++-
 6 files changed, 83 insertions(+), 4 deletions(-)
 create mode 100644 api/migrations/versions/f9107f83abab_add_desc_for_apps.py

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 898fd4f7c40a75..98636fa95f3e03 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,5 +1,5 @@ 
from flask_login import current_user -from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden, BadRequest from controllers.console import api @@ -53,6 +53,7 @@ def post(self): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('description', type=str, location='json') parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') @@ -86,6 +87,7 @@ def post(self): parser = reqparse.RequestParser() parser.add_argument('data', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, location='json') + parser.add_argument('description', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() @@ -144,6 +146,25 @@ def get(self, app_model): return app_model + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_detail_fields_with_site) + def put(self, app_model): + """Update app""" + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + app_service = AppService() + app_model = app_service.update_app(app_model, args) + + return app_model + @setup_required @login_required @account_initialization_required diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 75b68d24fcd449..69ab1d3e3e8c3f 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -5,6 +5,7 @@ app_detail_kernel_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -41,6 +42,7 @@ app_detail_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -62,6 +64,7 @@ app_partial_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -109,6 +112,7 @@ app_detail_fields_with_site = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py new file mode 100644 index 00000000000000..88d77bb32018af --- /dev/null +++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py @@ -0,0 +1,32 @@ +"""add desc for apps + +Revision ID: f9107f83abab +Revises: cc04d0998d4d +Create Date: 2024-02-28 08:16:14.090481 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f9107f83abab' +down_revision = 'cc04d0998d4d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('description') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 713d8da5775bf7..8d286d34827f71 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,7 +14,6 @@ from libs.helper import generate_string from .account import Account, Tenant -from .workflow import Workflow, WorkflowRun class DifySetup(db.Model): @@ -59,6 +58,7 @@ class App(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(UUID, nullable=False) name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) @@ -279,6 +279,7 @@ def file_upload_dict(self) -> dict: @property def workflow(self): if self.workflow_id: + from api.models.workflow import Workflow return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return None @@ -692,6 +693,7 @@ def files(self): @property def workflow_run(self): if self.workflow_run_id: + from api.models.workflow import WorkflowRun return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None diff --git a/api/models/workflow.py b/api/models/workflow.py index 7ea342cda782b2..316d3e623e2af2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,7 +5,6 @@ from extensions.ext_database import db from models.account import Account -from models.model import EndUser class CreatedByRole(Enum): @@ -242,6 +241,7 @@ def created_by_account(self): @property def created_by_end_user(self): + from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None @@ -356,6 +356,7 @@ def created_by_account(self): @property def created_by_end_user(self): + from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None @@ -418,6 +419,7 @@ def created_by_account(self): @property def created_by_end_user(self): + from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None diff --git a/api/services/app_service.py b/api/services/app_service.py index 5de87dbad5dda8..2e534eae158130 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -97,10 +97,11 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: app = App(**app_template['app']) app.name = args['name'] + app.description = args.get('description', '') app.mode = args['mode'] app.icon = args['icon'] app.icon_background = args['icon_background'] - app.tenant_id = account.current_tenant_id + app.tenant_id = tenant_id db.session.add(app) db.session.flush() @@ -145,6 +146,7 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App: 
             tenant_id=tenant_id,
             mode=app_data.get('mode'),
             name=args.get("name") if args.get("name") else app_data.get('name'),
+            description=args.get("description") if args.get("description") else app_data.get('description', ''),
             icon=args.get("icon") if args.get("icon") else app_data.get('icon'),
             icon_background=args.get("icon_background") if args.get("icon_background") \
                 else app_data.get('icon_background'),
@@ -205,6 +207,22 @@ def export_app(self, app: App) -> str:
 
         return yaml.dump(export_data)
 
+    def update_app(self, app: App, args: dict) -> App:
+        """
+        Update app
+        :param app: App instance
+        :param args: request args
+        :return: App instance
+        """
+        app.name = args.get('name')
+        app.description = args.get('description', '')
+        app.icon = args.get('icon')
+        app.icon_background = args.get('icon_background')
+        app.updated_at = datetime.utcnow()
+        db.session.commit()
+
+        return app
+
     def update_app_name(self, app: App, name: str) -> App:
         """
         Update app name
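Note: with this patch a full update goes through a single PUT instead of the per-field POST endpoints. A request sketch, assuming AppApi stays mounted at /apps/<uuid:app_id> under the console API prefix, with placeholder `app_id` and `token` values:

    import requests

    resp = requests.put(
        f"https://dify.example.com/console/api/apps/{app_id}",
        json={
            "name": "Customer FAQ Bot",
            "description": "Answers FAQ-style questions",
            "icon": "🤖",
            "icon_background": "#FFEAD5",
        },
        headers={"Authorization": f"Bearer {token}"},
    )
    updated = resp.json()  # marshaled with app_detail_fields_with_site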
From 11337e51c54ce2574dbde767337450567804e18d Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 16:27:49 +0800
Subject: [PATCH 037/160] lint fix

---
 api/migrations/versions/f9107f83abab_add_desc_for_apps.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py
index 88d77bb32018af..3e5ae0d67d7e58 100644
--- a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py
+++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py
@@ -5,9 +5,8 @@
 Create Date: 2024-02-28 08:16:14.090481
 
 """
-from alembic import op
 import sqlalchemy as sa
-
+from alembic import op
 
 # revision identifiers, used by Alembic.
 revision = 'f9107f83abab'

From 022b7d5dd442621cbb7044df2b7fee6ad2c4bbbe Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 18:24:49 +0800
Subject: [PATCH 038/160] optimize default model exceptions

---
 api/services/app_service.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/api/services/app_service.py b/api/services/app_service.py
index 2e534eae158130..298cd650df8f8f 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from datetime import datetime
 from typing import cast
 
@@ -7,7 +7,7 @@
 from flask_sqlalchemy.pagination import Pagination
 
 from constants.model_template import default_app_templates
-from core.errors.error import ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -74,7 +75,10 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
                 tenant_id=account.current_tenant_id,
                 model_type=ModelType.LLM
             )
-        except ProviderTokenNotInitError:
+        except (ProviderTokenNotInitError, LLMBadRequestError):
+            model_instance = None
+        except Exception as e:
+            logging.exception(e)
             model_instance = None
 
         if model_instance:

From dd70aeff247be188c834e8af06efab3c0c0e61be Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 18:27:16 +0800
Subject: [PATCH 039/160] lint fix

---
 api/services/app_service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/services/app_service.py b/api/services/app_service.py
index 298cd650df8f8f..374727d2d42495 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -7,7 +7,7 @@
 from flask_sqlalchemy.pagination import Pagination
 
 from constants.model_template import default_app_templates
-from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

From 77618823a5c1da589f9d32732d3b8ef0b7907b83 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 28 Feb 2024 22:16:24 +0800
Subject: [PATCH 040/160] add features update api refactor app model config validation

---
 api/controllers/console/app/model_config.py        |  43 +-
 api/core/apps/__init__.py                          |   0
 api/core/apps/app_config_validators/__init__.py    |   0
 .../advanced_chat_app.py                           |  54 ++
 .../app_config_validators/agent_chat_app.py        |  82 +++
 .../apps/app_config_validators/chat_app.py         |  82 +++
 .../app_config_validators/completion_app.py        |  67 +++
 .../app_config_validators/workflow_app.py          |  34 ++
 api/core/apps/config_validators/__init__.py        |   0
 api/core/apps/config_validators/agent.py           |  82 +++
 api/core/apps/config_validators/dataset.py         | 141 +++++
 .../config_validators/external_data_tools.py       |  40 ++
 .../apps/config_validators/file_upload.py          |  38 ++
 api/core/apps/config_validators/model.py           |  83 +++
 api/core/apps/config_validators/moderation.py      |  36 ++
 .../apps/config_validators/more_like_this.py       |  26 +
 .../config_validators/opening_statement.py         |  29 +
 api/core/apps/config_validators/prompt.py          |  87 +++
 .../config_validators/retriever_resource.py        |  26 +
 .../apps/config_validators/speech_to_text.py       |  26 +
 .../config_validators/suggested_questions.py       |  26 +
 .../apps/config_validators/text_to_speech.py       |  30 +
 .../apps/config_validators/user_input_form.py      |  62 ++
 api/services/app_model_config_service.py           | 539 +-----------------
 api/services/completion_service.py                 |  11 +-
 api/services/workflow_service.py                   |   2 +-
 26 files changed, 1115 insertions(+), 531 deletions(-)
 create mode 100644 api/core/apps/__init__.py
 create mode 100644 api/core/apps/app_config_validators/__init__.py
 create mode 100644 api/core/apps/app_config_validators/advanced_chat_app.py
 create mode 100644 api/core/apps/app_config_validators/agent_chat_app.py
 create mode 100644 api/core/apps/app_config_validators/chat_app.py
 create mode 100644 api/core/apps/app_config_validators/completion_app.py
 create mode 100644 api/core/apps/app_config_validators/workflow_app.py
 create mode 100644 api/core/apps/config_validators/__init__.py
 create mode 100644 api/core/apps/config_validators/agent.py
 create mode 100644 api/core/apps/config_validators/dataset.py
 create mode 100644 api/core/apps/config_validators/external_data_tools.py
 create mode 100644 api/core/apps/config_validators/file_upload.py
 create mode 100644 api/core/apps/config_validators/model.py
 create mode 100644 api/core/apps/config_validators/moderation.py
 create mode 100644 api/core/apps/config_validators/more_like_this.py
 create mode 100644 api/core/apps/config_validators/opening_statement.py
 create mode 100644 api/core/apps/config_validators/prompt.py
 create mode 100644 api/core/apps/config_validators/retriever_resource.py
 create mode 100644 api/core/apps/config_validators/speech_to_text.py
 create mode 100644 api/core/apps/config_validators/suggested_questions.py
 create mode 100644 api/core/apps/config_validators/text_to_speech.py
 create mode 100644 
api/core/apps/config_validators/user_input_form.py diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 0f8bc28f6fe14f..0ae9f5e5464080 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -2,7 +2,7 @@ from flask import request from flask_login import current_user -from flask_restful import Resource +from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -14,7 +14,7 @@ from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required -from models.model import AppModelConfig +from models.model import AppModelConfig, AppMode from services.app_model_config_service import AppModelConfigService @@ -23,15 +23,14 @@ class ModelConfigResource(Resource): @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, - account=current_user, config=request.json, - app_mode=app_model.mode + app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( @@ -138,4 +137,38 @@ def post(self, app_model): return {'result': 'success'} +class FeaturesResource(Resource): + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def put(self, app_model): + """Get app features""" + parser = reqparse.RequestParser() + parser.add_argument('features', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + model_configuration = AppModelConfigService.validate_features( + tenant_id=current_user.current_tenant_id, + config=args.get('features'), + app_mode=AppMode.value_of(app_model.mode) + ) + + # update config + app_model_config = app_model.app_model_config + app_model_config.from_model_config_dict(model_configuration) + db.session.commit() + + app_model_config_was_updated.send( + app_model, + app_model_config=app_model_config + ) + + return { + 'result': 'success' + } + + api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(FeaturesResource, '/apps//features') diff --git a/api/core/apps/__init__.py b/api/core/apps/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/apps/app_config_validators/__init__.py b/api/core/apps/app_config_validators/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/apps/app_config_validators/advanced_chat_app.py b/api/core/apps/app_config_validators/advanced_chat_app.py new file mode 100644 index 00000000000000..dc7664b844f68c --- /dev/null +++ b/api/core/apps/app_config_validators/advanced_chat_app.py @@ -0,0 +1,54 @@ +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from 
core.apps.config_validators.text_to_speech import TextToSpeechValidator + + +class AdvancedChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/agent_chat_app.py b/api/core/apps/app_config_validators/agent_chat_app.py new file mode 100644 index 00000000000000..d507fae685613b --- /dev/null +++ b/api/core/apps/app_config_validators/agent_chat_app.py @@ -0,0 +1,82 @@ +from core.apps.config_validators.agent import AgentValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class AgentChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # 
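The model step just above is the strictest link in the agent-chat chain: ModelValidator requires a 'model' block naming a configured provider and an available model, then fills in 'mode' from the model's properties and normalizes 'completion_params'. A minimal sketch of a config fragment that would pass that step; the provider and model names are placeholders, not values taken from this patch:

model_config_stub = {
    "model": {
        "provider": "openai",              # placeholder; must be a provider configured for the tenant
        "name": "gpt-3.5-turbo",           # placeholder; must appear in the tenant's LLM list
        "completion_params": {"stop": []}  # required; "stop" may hold at most 4 sequences
    }
}
# ModelValidator.validate_and_set_defaults(tenant_id, model_config_stub) would return the
# config with model.mode filled in, plus the single key it owns: (config, ["model"]).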
user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/chat_app.py b/api/core/apps/app_config_validators/chat_app.py new file mode 100644 index 00000000000000..83c792e610c386 --- /dev/null +++ b/api/core/apps/app_config_validators/chat_app.py @@ -0,0 +1,82 @@ +from core.apps.config_validators.dataset import DatasetValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class ChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate 
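Every validator called above follows the same contract: normalize its own slice of the config, raise ValueError on malformed input, and return the config together with the list of top-level keys it owns so the caller can whitelist-filter at the end. A minimal hypothetical validator in that style; the 'example_flag' feature is invented purely for illustration:

from typing import Tuple


class ExampleFlagValidator:
    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
        # default a missing or falsy block to a disabled feature
        if not config.get("example_flag"):
            config["example_flag"] = {"enabled": False}

        if not isinstance(config["example_flag"], dict):
            raise ValueError("example_flag must be of dict type")

        if not isinstance(config["example_flag"].get("enabled", False), bool):
            raise ValueError("enabled in example_flag must be of boolean type")

        # the second element tells the caller which keys this validator owns
        return config, ["example_flag"]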
for chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/completion_app.py b/api/core/apps/app_config_validators/completion_app.py new file mode 100644 index 00000000000000..00371f8d05fadd --- /dev/null +++ b/api/core/apps/app_config_validators/completion_app.py @@ -0,0 +1,67 @@ +from core.apps.config_validators.dataset import DatasetValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.more_like_this import MoreLikeThisValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class 
CompletionAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/workflow_app.py b/api/core/apps/app_config_validators/workflow_app.py new file mode 100644 index 00000000000000..545d3d79a330e5 --- /dev/null +++ b/api/core/apps/app_config_validators/workflow_app.py @@ -0,0 +1,34 @@ +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator + + +class WorkflowAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for workflow app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = 
{key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/config_validators/__init__.py b/api/core/apps/config_validators/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py new file mode 100644 index 00000000000000..69f93380801b51 --- /dev/null +++ b/api/core/apps/config_validators/agent.py @@ -0,0 +1,82 @@ +import uuid +from typing import Tuple + +from core.agent.agent_executor import PlanningStrategy +from core.apps.config_validators.dataset import DatasetValidator + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py new file mode 100644 index 00000000000000..32db038c21a0d1 --- /dev/null +++ b/api/core/apps/config_validators/dataset.py @@ -0,0 +1,141 @@ +import uuid +from 
typing import Tuple
+
+from core.agent.agent_executor import PlanningStrategy
+from models.model import AppMode
+from services.dataset_service import DatasetService
+
+
+class DatasetValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for dataset feature
+
+        :param tenant_id: tenant ID
+        :param app_mode: app mode
+        :param config: app model config args
+        """
+        # Extract dataset config for legacy compatibility
+        config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
+
+        # dataset_configs
+        if not config.get("dataset_configs"):
+            config["dataset_configs"] = {'retrieval_model': 'single'}
+
+        # check the type once, before indexing into it
+        if not isinstance(config["dataset_configs"], dict):
+            raise ValueError("dataset_configs must be of object type")
+
+        if not config["dataset_configs"].get("datasets"):
+            config["dataset_configs"]["datasets"] = {
+                "strategy": "router",
+                "datasets": []
+            }
+
+        if config["dataset_configs"].get('retrieval_model') == 'multiple':
+            if not config["dataset_configs"].get('reranking_model'):
+                raise ValueError("reranking_model has not been set")
+            if not isinstance(config["dataset_configs"]['reranking_model'], dict):
+                raise ValueError("reranking_model must be of object type")
+
+        # require a query variable only when datasets are actually configured
+        need_manual_query_datasets = config["dataset_configs"]["datasets"].get("datasets")
+
+        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
+            # Only check when mode is completion
+            dataset_query_variable = config.get("dataset_query_variable")
+
+            if not dataset_query_variable:
+                raise ValueError("Dataset query variable is required when a dataset exists")
+
+        return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
+
+    @classmethod
+    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
+        """
+        Extract dataset config for legacy compatibility
+
+        :param tenant_id: tenant ID
+        :param app_mode: app mode
+        :param config: app model config args
+        """
+        # Extract dataset config for legacy compatibility
+        if not config.get("agent_mode"):
+            config["agent_mode"] = {
+                "enabled": False,
+                "tools": []
+            }
+
+        if not isinstance(config["agent_mode"], dict):
+            raise ValueError("agent_mode must be of object type")
+
+        # enabled
+        if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
+            config["agent_mode"]["enabled"] = False
+
+        if not isinstance(config["agent_mode"]["enabled"], bool):
+            raise ValueError("enabled in agent_mode must be of boolean type")
+
+        # tools
+        if not config["agent_mode"].get("tools"):
+            config["agent_mode"]["tools"] = []
+
+        if not isinstance(config["agent_mode"]["tools"], list):
+            raise ValueError("tools in agent_mode must be a list of objects")
+
+        # strategy
+        if not config["agent_mode"].get("strategy"):
+            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+
+        has_datasets = False
+        if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
+            for tool in config["agent_mode"]["tools"]:
+                key = list(tool.keys())[0]
+                if key == "dataset":
+                    # old style, use tool name as key
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        tool_item["enabled"] = False
+
+                    if not isinstance(tool_item["enabled"], bool):
+                        raise ValueError("enabled in agent_mode.tools must be of boolean type")
+
+                    if 'id' not in tool_item:
+                        raise ValueError("id is required in dataset")
+
+                    try:
+                        uuid.UUID(tool_item["id"])
+                    except ValueError:
+                        raise ValueError("id in dataset must be of UUID type")
+
+                    if not cls.is_dataset_exists(tenant_id, tool_item["id"]):
+                        raise ValueError("Dataset ID does not exist, please check your permission.")
+
+                    has_datasets = True
+
+        need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
+
+        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
+            # Only check when mode is completion
+            dataset_query_variable = config.get("dataset_query_variable")
+
+            if not dataset_query_variable:
+                raise ValueError("Dataset query variable is required when a dataset exists")
+
+        return config
+
+    @classmethod
+    def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool:
+        # verify if the dataset ID exists
+        dataset = DatasetService.get_dataset(dataset_id)
+
+        if not dataset:
+            return False
+
+        if dataset.tenant_id != tenant_id:
+            return False
+
+        return True
diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/apps/config_validators/external_data_tools.py
new file mode 100644
index 00000000000000..5412366a897c3c
--- /dev/null
+++ b/api/core/apps/config_validators/external_data_tools.py
@@ -0,0 +1,40 @@
+from typing import Tuple
+
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+
+class ExternalDataToolsValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for external data fetch feature
+
+        :param tenant_id: workspace id
+        :param config: app model config args
+        """
+        if not config.get("external_data_tools"):
+            config["external_data_tools"] = []
+
+        if not isinstance(config["external_data_tools"], list):
+            raise ValueError("external_data_tools must be of list type")
+
+        for tool in config["external_data_tools"]:
+            if "enabled" not in tool or not tool["enabled"]:
+                tool["enabled"] = False
+
+            if not tool["enabled"]:
+                continue
+
+            if "type" not in tool or not tool["type"]:
+                raise ValueError("external_data_tools[].type is required")
+
+            # bind the tool's config to a local name so the app config is still returned intact
+            typ = tool["type"]
+            tool_config = tool["config"]
+
+            ExternalDataToolFactory.validate_config(
+                name=typ,
+                tenant_id=tenant_id,
+                config=tool_config
+            )
+
+        return config, ["external_data_tools"]
diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/apps/config_validators/file_upload.py
new file mode 100644
index 00000000000000..f9adbfdf7df3a1
--- /dev/null
+++ b/api/core/apps/config_validators/file_upload.py
@@ -0,0 +1,38 @@
+from typing import Tuple
+
+
+class FileUploadValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for file upload feature
+
+        :param config: app model config args
+        """
+        if not config.get("file_upload"):
+            config["file_upload"] = {}
+
+        if not isinstance(config["file_upload"], dict):
+            raise ValueError("file_upload must be of dict type")
+
+        # check image config
+        if not config["file_upload"].get("image"):
+            config["file_upload"]["image"] = {"enabled": False}
+
+        if config['file_upload']['image']['enabled']:
+            number_limits = config['file_upload']['image']['number_limits']
+            if number_limits < 1 or number_limits > 6:
+                raise ValueError("number_limits must be in [1, 6]")
+
+            detail = config['file_upload']['image']['detail']
+            if detail not in ['high', 'low']:
+                raise ValueError("detail must be in 
['high', 'low']") + + transfer_methods = config['file_upload']['image']['transfer_methods'] + if not isinstance(transfer_methods, list): + raise ValueError("transfer_methods must be of list type") + for method in transfer_methods: + if method not in ['remote_url', 'local_file']: + raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") + + return config, ["file_upload"] diff --git a/api/core/apps/config_validators/model.py b/api/core/apps/config_validators/model.py new file mode 100644 index 00000000000000..091eec4683a37f --- /dev/null +++ b/api/core/apps/config_validators/model.py @@ -0,0 +1,83 @@ +from typing import Tuple + +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager + + +class ModelValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for model config + + :param tenant_id: tenant id + :param config: app model config args + """ + if 'model' not in config: + raise ValueError("model is required") + + if not isinstance(config["model"], dict): + raise ValueError("model must be of object type") + + # model.provider + provider_entities = model_provider_factory.get_providers() + model_provider_names = [provider.provider for provider in provider_entities] + if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") + + # model.name + if 'name' not in config["model"]: + raise ValueError("model.name is required") + + provider_manager = ProviderManager() + models = provider_manager.get_configurations(tenant_id).get_models( + provider=config["model"]["provider"], + model_type=ModelType.LLM + ) + + if not models: + raise ValueError("model.name must be in the specified model list") + + model_ids = [m.model for m in models] + if config["model"]["name"] not in model_ids: + raise ValueError("model.name must be in the specified model list") + + model_mode = None + for model in models: + if model.model == config["model"]["name"]: + model_mode = model.model_properties.get(ModelPropertyKey.MODE) + break + + # model.mode + if model_mode: + config['model']["mode"] = model_mode + else: + config['model']["mode"] = "completion" + + # model.completion_params + if 'completion_params' not in config["model"]: + raise ValueError("model.completion_params is required") + + config["model"]["completion_params"] = cls.validate_model_completion_params( + config["model"]["completion_params"] + ) + + return config, ["model"] + + @classmethod + def validate_model_completion_params(cls, cp: dict) -> dict: + # model.completion_params + if not isinstance(cp, dict): + raise ValueError("model.completion_params must be of object type") + + # stop + if 'stop' not in cp: + cp["stop"] = [] + elif not isinstance(cp["stop"], list): + raise ValueError("stop in model.completion_params must be of list type") + + if len(cp["stop"]) > 4: + raise ValueError("stop sequences must be less than 4") + + return cp diff --git a/api/core/apps/config_validators/moderation.py b/api/core/apps/config_validators/moderation.py new file mode 100644 index 00000000000000..1962f87aa9b181 --- /dev/null +++ b/api/core/apps/config_validators/moderation.py @@ -0,0 +1,36 @@ +import logging +from typing import Tuple + +from core.moderation.factory import 
ModerationFactory
+
+logger = logging.getLogger(__name__)
+
+
+class ModerationValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]:
+        if not config.get("sensitive_word_avoidance"):
+            config["sensitive_word_avoidance"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["sensitive_word_avoidance"], dict):
+            raise ValueError("sensitive_word_avoidance must be of dict type")
+
+        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
+            config["sensitive_word_avoidance"]["enabled"] = False
+
+        if config["sensitive_word_avoidance"]["enabled"]:
+            if not config["sensitive_word_avoidance"].get("type"):
+                raise ValueError("sensitive_word_avoidance.type is required")
+
+            typ = config["sensitive_word_avoidance"]["type"]
+            moderation_config = config["sensitive_word_avoidance"]["config"]
+
+            ModerationFactory.validate_config(
+                name=typ,
+                tenant_id=tenant_id,
+                config=moderation_config
+            )
+
+        return config, ["sensitive_word_avoidance"]
diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/apps/config_validators/more_like_this.py
new file mode 100644
index 00000000000000..60dc4a056255fa
--- /dev/null
+++ b/api/core/apps/config_validators/more_like_this.py
@@ -0,0 +1,26 @@
+from typing import Tuple
+
+
+class MoreLikeThisValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for more like this feature
+
+        :param config: app model config args
+        """
+        if not config.get("more_like_this"):
+            config["more_like_this"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["more_like_this"], dict):
+            raise ValueError("more_like_this must be of dict type")
+
+        if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
+            config["more_like_this"]["enabled"] = False
+
+        if not isinstance(config["more_like_this"]["enabled"], bool):
+            raise ValueError("enabled in more_like_this must be of boolean type")
+
+        return config, ["more_like_this"]
diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/apps/config_validators/opening_statement.py
new file mode 100644
index 00000000000000..3f69e0e9469430
--- /dev/null
+++ b/api/core/apps/config_validators/opening_statement.py
@@ -0,0 +1,29 @@
+from typing import Tuple
+
+
+class OpeningStatementValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for opening statement feature
+
+        :param config: app model config args
+        """
+        if not config.get("opening_statement"):
+            config["opening_statement"] = ""
+
+        if not isinstance(config["opening_statement"], str):
+            raise ValueError("opening_statement must be of string type")
+
+        # suggested_questions
+        if not config.get("suggested_questions"):
+            config["suggested_questions"] = []
+
+        if not isinstance(config["suggested_questions"], list):
+            raise ValueError("suggested_questions must be of list type")
+
+        for question in config["suggested_questions"]:
+            if not isinstance(question, str):
+                raise ValueError("Elements in suggested_questions list must be of string type")
+
+        return config, ["opening_statement", "suggested_questions"]
diff --git a/api/core/apps/config_validators/prompt.py b/api/core/apps/config_validators/prompt.py
new file mode 100644
index 00000000000000..815706b10b9b3e
--- /dev/null
+++ b/api/core/apps/config_validators/prompt.py
@@ -0,0 +1,87 @@
+from typing import Tuple
+
+from 
core.entities.application_entities import PromptTemplateEntity +from core.prompt.simple_prompt_transform import ModelMode +from models.model import AppMode + + +class PromptValidator: + @classmethod + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + """ + Validate pre_prompt and set defaults for prompt feature + depending on the config['model'] + + :param app_mode: app mode + :param config: app model config args + """ + if not config.get("prompt_type"): + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + + prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + if config['prompt_type'] not in prompt_type_vals: + raise ValueError(f"prompt_type must be in {prompt_type_vals}") + + # chat_prompt_config + if not config.get("chat_prompt_config"): + config["chat_prompt_config"] = {} + + if not isinstance(config["chat_prompt_config"], dict): + raise ValueError("chat_prompt_config must be of object type") + + # completion_prompt_config + if not config.get("completion_prompt_config"): + config["completion_prompt_config"] = {} + + if not isinstance(config["completion_prompt_config"], dict): + raise ValueError("completion_prompt_config must be of object type") + + if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config['chat_prompt_config'] and not config['completion_prompt_config']: + raise ValueError("chat_prompt_config or completion_prompt_config is required " + "when prompt_type is advanced") + + model_mode_vals = [mode.value for mode in ModelMode] + if config['model']["mode"] not in model_mode_vals: + raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") + + if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: + user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] + assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + + if not user_prefix: + config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + + if not assistant_prefix: + config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + + if config['model']["mode"] == ModelMode.CHAT.value: + prompt_list = config['chat_prompt_config']['prompt'] + + if len(prompt_list) > 10: + raise ValueError("prompt messages must be less than 10") + else: + # pre_prompt, for simple mode + if not config.get("pre_prompt"): + config["pre_prompt"] = "" + + if not isinstance(config["pre_prompt"], str): + raise ValueError("pre_prompt must be of string type") + + return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] + + @classmethod + def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: + """ + Validate post_prompt and set defaults for prompt feature + + :param config: app model config args + """ + # post_prompt + if not config.get("post_prompt"): + config["post_prompt"] = "" + + if not isinstance(config["post_prompt"], str): + raise ValueError("post_prompt must be of string type") + + return config \ No newline at end of file diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/apps/config_validators/retriever_resource.py new file mode 100644 index 00000000000000..a8bcd60abef5c9 --- /dev/null +++ b/api/core/apps/config_validators/retriever_resource.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class 
RetrieverResourceValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for retriever resource feature + + :param config: app model config args + """ + if not config.get("retriever_resource"): + config["retriever_resource"] = { + "enabled": False + } + + if not isinstance(config["retriever_resource"], dict): + raise ValueError("retriever_resource must be of dict type") + + if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: + config["retriever_resource"]["enabled"] = False + + if not isinstance(config["retriever_resource"]["enabled"], bool): + raise ValueError("enabled in retriever_resource must be of boolean type") + + return config, ["retriever_resource"] diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/apps/config_validators/speech_to_text.py new file mode 100644 index 00000000000000..577bef0e59f8fb --- /dev/null +++ b/api/core/apps/config_validators/speech_to_text.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class SpeechToTextValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for speech to text feature + + :param config: app model config args + """ + if not config.get("speech_to_text"): + config["speech_to_text"] = { + "enabled": False + } + + if not isinstance(config["speech_to_text"], dict): + raise ValueError("speech_to_text must be of dict type") + + if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: + config["speech_to_text"]["enabled"] = False + + if not isinstance(config["speech_to_text"]["enabled"], bool): + raise ValueError("enabled in speech_to_text must be of boolean type") + + return config, ["speech_to_text"] diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/apps/config_validators/suggested_questions.py new file mode 100644 index 00000000000000..938b66bb6ecb38 --- /dev/null +++ b/api/core/apps/config_validators/suggested_questions.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class SuggestedQuestionsValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for suggested questions feature + + :param config: app model config args + """ + if not config.get("suggested_questions_after_answer"): + config["suggested_questions_after_answer"] = { + "enabled": False + } + + if not isinstance(config["suggested_questions_after_answer"], dict): + raise ValueError("suggested_questions_after_answer must be of dict type") + + if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: + config["suggested_questions_after_answer"]["enabled"] = False + + if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): + raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") + + return config, ["suggested_questions_after_answer"] diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/apps/config_validators/text_to_speech.py new file mode 100644 index 00000000000000..efe34a8a3e902d --- /dev/null +++ b/api/core/apps/config_validators/text_to_speech.py @@ -0,0 +1,30 @@ +from typing import Tuple + + +class TextToSpeechValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for text to speech 
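The toggle-style validators in this group (retriever resource, speech to text, suggested questions) all collapse a missing or half-filled block to {'enabled': False} and reject a non-boolean flag. A short behavioral sketch, assuming the import path introduced by this patch and a repository context where it resolves:

from core.apps.config_validators.speech_to_text import SpeechToTextValidator

config, keys = SpeechToTextValidator.validate_and_set_defaults({})
assert config["speech_to_text"] == {"enabled": False}
assert keys == ["speech_to_text"]

try:
    SpeechToTextValidator.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}})
except ValueError as e:
    print(e)  # enabled in speech_to_text must be of boolean type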
feature
+
+        :param config: app model config args
+        """
+        if not config.get("text_to_speech"):
+            config["text_to_speech"] = {
+                "enabled": False,
+                "voice": "",
+                "language": ""
+            }
+
+        if not isinstance(config["text_to_speech"], dict):
+            raise ValueError("text_to_speech must be of dict type")
+
+        if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
+            config["text_to_speech"]["enabled"] = False
+            config["text_to_speech"]["voice"] = ""
+            config["text_to_speech"]["language"] = ""
+
+        if not isinstance(config["text_to_speech"]["enabled"], bool):
+            raise ValueError("enabled in text_to_speech must be of boolean type")
+
+        return config, ["text_to_speech"]
diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/apps/config_validators/user_input_form.py
new file mode 100644
index 00000000000000..7116c55afc636c
--- /dev/null
+++ b/api/core/apps/config_validators/user_input_form.py
@@ -0,0 +1,62 @@
+import re
+from typing import Tuple
+
+
+class UserInputFormValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param config: app model config args
+        """
+        if not config.get("user_input_form"):
+            config["user_input_form"] = []
+
+        if not isinstance(config["user_input_form"], list):
+            raise ValueError("user_input_form must be a list of objects")
+
+        variables = []
+        for item in config["user_input_form"]:
+            key = list(item.keys())[0]
+            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+                raise ValueError("Keys in user_input_form list can only be 'text-input', 'select', 'paragraph', 'number' or 'external_data_tool'")
+
+            form_item = item[key]
+            if 'label' not in form_item:
+                raise ValueError("label is required in user_input_form")
+
+            if not isinstance(form_item["label"], str):
+                raise ValueError("label in user_input_form must be of string type")
+
+            if 'variable' not in form_item:
+                raise ValueError("variable is required in user_input_form")
+
+            if not isinstance(form_item["variable"], str):
+                raise ValueError("variable in user_input_form must be of string type")
+
+            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
+            if pattern.match(form_item["variable"]) is None:
+                raise ValueError("variable in user_input_form must be a string, "
+                                 "and cannot start with a number")
+
+            variables.append(form_item["variable"])
+
+            if 'required' not in form_item or not form_item["required"]:
+                form_item["required"] = False
+
+            if not isinstance(form_item["required"], bool):
+                raise ValueError("required in user_input_form must be of boolean type")
+
+            if key == "select":
+                if 'options' not in form_item or not form_item["options"]:
+                    form_item["options"] = []
+
+                if not isinstance(form_item["options"], list):
+                    raise ValueError("options in user_input_form must be a list of strings")
+
+                if "default" in form_item and form_item['default'] \
+                        and form_item["default"] not in form_item["options"]:
+                    raise ValueError("default value in user_input_form must be in the options list")
+
+        return config, ["user_input_form"]
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index 34b6d62d51a47c..c1e0ecebe82a1f 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -1,528 +1,29 @@
-import re
-import uuid
-
-from core.entities.agent_entities import PlanningStrategy
-from core.entities.application_entities import AppMode
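The variable-name pattern in UserInputFormValidator above accepts CJK characters, ASCII word characters and two emoji ranges, caps the length at 100, and uses a negative lookahead to forbid a leading digit. A quick demonstration with the same regex:

import re

pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")

for var in ["user_name", "query1", "1query", "", "city-name"]:
    print(var or "(empty)", bool(pattern.match(var)))
# user_name True, query1 True,
# 1query False (leading digit), (empty) False (at least one character required),
# city-name False (hyphen is not in the character class)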
-from core.external_data_tool.factory import ExternalDataToolFactory -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers import model_provider_factory -from core.moderation.factory import ModerationFactory -from core.provider_manager import ProviderManager -from models.account import Account +from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator +from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator +from core.apps.app_config_validators.chat_app import ChatAppConfigValidator +from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator +from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator from models.model import AppMode -from services.dataset_service import DatasetService - -SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] class AppModelConfigService: - @classmethod - def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool: - # verify if the dataset ID exists - dataset = DatasetService.get_dataset(dataset_id) - - if not dataset: - return False - - if dataset.tenant_id != account.current_tenant_id: - return False - - return True - - @classmethod - def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: - # 6. model.completion_params - if not isinstance(cp, dict): - raise ValueError("model.completion_params must be of object type") - - # stop - if 'stop' not in cp: - cp["stop"] = [] - elif not isinstance(cp["stop"], list): - raise ValueError("stop in model.completion_params must be of list type") - - if len(cp["stop"]) > 4: - raise ValueError("stop sequences must be less than 4") - - return cp @classmethod - def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict: - # opening_statement - if 'opening_statement' not in config or not config["opening_statement"]: - config["opening_statement"] = "" - - if not isinstance(config["opening_statement"], str): - raise ValueError("opening_statement must be of string type") - - # suggested_questions - if 'suggested_questions' not in config or not config["suggested_questions"]: - config["suggested_questions"] = [] - - if not isinstance(config["suggested_questions"], list): - raise ValueError("suggested_questions must be of list type") - - for question in config["suggested_questions"]: - if not isinstance(question, str): - raise ValueError("Elements in suggested_questions list must be of string type") - - # suggested_questions_after_answer - if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]: - config["suggested_questions_after_answer"] = { - "enabled": False - } - - if not isinstance(config["suggested_questions_after_answer"], dict): - raise ValueError("suggested_questions_after_answer must be of dict type") - - if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: - config["suggested_questions_after_answer"]["enabled"] = False - - if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): - raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") - - # speech_to_text - if 'speech_to_text' not in config or not config["speech_to_text"]: - config["speech_to_text"] = { - "enabled": False - } - - if not isinstance(config["speech_to_text"], dict): - raise 
ValueError("speech_to_text must be of dict type") - - if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: - config["speech_to_text"]["enabled"] = False - - if not isinstance(config["speech_to_text"]["enabled"], bool): - raise ValueError("enabled in speech_to_text must be of boolean type") - - # text_to_speech - if 'text_to_speech' not in config or not config["text_to_speech"]: - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } - - if not isinstance(config["text_to_speech"], dict): - raise ValueError("text_to_speech must be of dict type") - - if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]: - config["text_to_speech"]["enabled"] = False - config["text_to_speech"]["voice"] = "" - config["text_to_speech"]["language"] = "" - - if not isinstance(config["text_to_speech"]["enabled"], bool): - raise ValueError("enabled in text_to_speech must be of boolean type") - - # return retriever resource - if 'retriever_resource' not in config or not config["retriever_resource"]: - config["retriever_resource"] = { - "enabled": False - } - - if not isinstance(config["retriever_resource"], dict): - raise ValueError("retriever_resource must be of dict type") - - if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: - config["retriever_resource"]["enabled"] = False - - if not isinstance(config["retriever_resource"]["enabled"], bool): - raise ValueError("enabled in retriever_resource must be of boolean type") - - # more_like_this - if 'more_like_this' not in config or not config["more_like_this"]: - config["more_like_this"] = { - "enabled": False - } - - if not isinstance(config["more_like_this"], dict): - raise ValueError("more_like_this must be of dict type") - - if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: - config["more_like_this"]["enabled"] = False - - if not isinstance(config["more_like_this"]["enabled"], bool): - raise ValueError("enabled in more_like_this must be of boolean type") - - # model - if 'model' not in config: - raise ValueError("model is required") - - if not isinstance(config["model"], dict): - raise ValueError("model must be of object type") - - # model.provider - provider_entities = model_provider_factory.get_providers() - model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: - raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") - - # model.name - if 'name' not in config["model"]: - raise ValueError("model.name is required") - - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM - ) - if not models: - raise ValueError("model.name must be in the specified model list") - - model_ids = [m.model for m in models] - if config["model"]["name"] not in model_ids: - raise ValueError("model.name must be in the specified model list") - - model_mode = None - for model in models: - if model.model == config["model"]["name"]: - model_mode = model.model_properties.get(ModelPropertyKey.MODE) - break - - # model.mode - if model_mode: - config['model']["mode"] = model_mode + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + if app_mode == AppMode.CHAT: + return 
ChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.AGENT_CHAT: + return AgentChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.COMPLETION: + return CompletionAppConfigValidator.config_validate(tenant_id, config) else: - config['model']["mode"] = "completion" - - # model.completion_params - if 'completion_params' not in config["model"]: - raise ValueError("model.completion_params is required") - - config["model"]["completion_params"] = cls.validate_model_completion_params( - config["model"]["completion_params"], - config["model"]["name"] - ) - - # user_input_form - if "user_input_form" not in config or not config["user_input_form"]: - config["user_input_form"] = [] - - if not isinstance(config["user_input_form"], list): - raise ValueError("user_input_form must be a list of objects") - - variables = [] - for item in config["user_input_form"]: - key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: - raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") - - form_item = item[key] - if 'label' not in form_item: - raise ValueError("label is required in user_input_form") - - if not isinstance(form_item["label"], str): - raise ValueError("label in user_input_form must be of string type") - - if 'variable' not in form_item: - raise ValueError("variable is required in user_input_form") - - if not isinstance(form_item["variable"], str): - raise ValueError("variable in user_input_form must be of string type") - - pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") - if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") - - variables.append(form_item["variable"]) - - if 'required' not in form_item or not form_item["required"]: - form_item["required"] = False - - if not isinstance(form_item["required"], bool): - raise ValueError("required in user_input_form must be of boolean type") - - if key == "select": - if 'options' not in form_item or not form_item["options"]: - form_item["options"] = [] - - if not isinstance(form_item["options"], list): - raise ValueError("options in user_input_form must be a list of strings") - - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: - raise ValueError("default value in user_input_form must be in the options list") - - # pre_prompt - if "pre_prompt" not in config or not config["pre_prompt"]: - config["pre_prompt"] = "" - - if not isinstance(config["pre_prompt"], str): - raise ValueError("pre_prompt must be of string type") - - # agent_mode - if "agent_mode" not in config or not config["agent_mode"]: - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]: - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in 
list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]: - config["agent_mode"]["tools"] = [] - - if not isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in SUPPORT_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not cls.is_dataset_exists(account, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - - # advanced prompt validation - cls.is_advanced_prompt_valid(config, app_mode) - - # external data tools validation - cls.is_external_data_tools_valid(tenant_id, config) - - # moderation validation - cls.is_moderation_valid(tenant_id, config) - - # file upload validation - cls.is_file_upload_valid(config) - - # Filter out extra parameters - filtered_config = { - "opening_statement": config["opening_statement"], - "suggested_questions": config["suggested_questions"], - "suggested_questions_after_answer": config["suggested_questions_after_answer"], - "speech_to_text": config["speech_to_text"], - "text_to_speech": config["text_to_speech"], - "retriever_resource": config["retriever_resource"], - "more_like_this": config["more_like_this"], - "sensitive_word_avoidance": config["sensitive_word_avoidance"], - "external_data_tools": config["external_data_tools"], - "model": { - "provider": config["model"]["provider"], - "name": config["model"]["name"], - "mode": config['model']["mode"], - "completion_params": config["model"]["completion_params"] - }, - "user_input_form": config["user_input_form"], - "dataset_query_variable": config.get('dataset_query_variable'), - "pre_prompt": config["pre_prompt"], - "agent_mode": config["agent_mode"], - "prompt_type": config["prompt_type"], - "chat_prompt_config": config["chat_prompt_config"], - "completion_prompt_config": config["completion_prompt_config"], - "dataset_configs": config["dataset_configs"], - "file_upload": config["file_upload"] - } - - return filtered_config - - @classmethod - def is_moderation_valid(cls, tenant_id: str, config: dict): - if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: - config["sensitive_word_avoidance"] = { - "enabled": False - } - - if not isinstance(config["sensitive_word_avoidance"], dict): - raise 
ValueError("sensitive_word_avoidance must be of dict type") - - if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: - config["sensitive_word_avoidance"]["enabled"] = False - - if not config["sensitive_word_avoidance"]["enabled"]: - return - - if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: - raise ValueError("sensitive_word_avoidance.type is required") - - type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] - - ModerationFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_file_upload_valid(cls, config: dict): - if 'file_upload' not in config or not config["file_upload"]: - config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be of dict type") - - # check image config - if 'image' not in config["file_upload"] or not config["file_upload"]["image"]: - config["file_upload"]["image"] = {"enabled": False} - - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: - raise ValueError("detail must be in ['high', 'low']") - - transfer_methods = config['file_upload']['image']['transfer_methods'] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in ['remote_url', 'local_file']: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") - - @classmethod - def is_external_data_tools_valid(cls, tenant_id: str, config: dict): - if 'external_data_tools' not in config or not config["external_data_tools"]: - config["external_data_tools"] = [] - - if not isinstance(config["external_data_tools"], list): - raise ValueError("external_data_tools must be of list type") - - for tool in config["external_data_tools"]: - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - - if not tool["enabled"]: - continue - - if "type" not in tool or not tool["type"]: - raise ValueError("external_data_tools[].type is required") - - type = tool["type"] - config = tool["config"] - - ExternalDataToolFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") + raise ValueError(f"Invalid app mode: {app_mode}") @classmethod - def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: - # prompt_type - if 'prompt_type' not in config or not config["prompt_type"]: - config["prompt_type"] = "simple" - - if config['prompt_type'] not in ['simple', 'advanced']: - raise ValueError("prompt_type must be in ['simple', 'advanced']") - - # chat_prompt_config - if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: - config["chat_prompt_config"] 
= {} - - if not isinstance(config["chat_prompt_config"], dict): - raise ValueError("chat_prompt_config must be of object type") - - # completion_prompt_config - if 'completion_prompt_config' not in config or not config["completion_prompt_config"]: - config["completion_prompt_config"] = {} - - if not isinstance(config["completion_prompt_config"], dict): - raise ValueError("completion_prompt_config must be of object type") - - # dataset_configs - if 'dataset_configs' not in config or not config["dataset_configs"]: - config["dataset_configs"] = {'retrieval_model': 'single'} - - if 'datasets' not in config["dataset_configs"] or not config["dataset_configs"]["datasets"]: - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config["dataset_configs"]['retrieval_model'] == 'multiple': - if not config["dataset_configs"]['reranking_model']: - raise ValueError("reranking_model has not been set") - if not isinstance(config["dataset_configs"]['reranking_model'], dict): - raise ValueError("reranking_model must be of object type") - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config['prompt_type'] == 'advanced': - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") - - if config['model']["mode"] not in ['chat', 'completion']: - raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - - if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion": - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] - - if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' - - if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' - - if config['model']["mode"] == "chat": - prompt_list = config['chat_prompt_config']['prompt'] - - if len(prompt_list) > 10: - raise ValueError("prompt messages must be less than 10") + def validate_features(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + if app_mode == AppMode.ADVANCED_CHAT: + return AdvancedChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.WORKFLOW: + return WorkflowAppConfigValidator.config_validate(tenant_id, config) + else: + raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef416b63..6dd729694ba2cf 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -5,10 +5,11 @@ from sqlalchemy import and_ from core.application_manager import ApplicationManager +from core.apps.config_validators.model import ModelValidator from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message, AppMode from services.app_model_config_service import 
AppModelConfigService from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError @@ -88,9 +89,8 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, if 'completion_params' not in args['model_config']['model']: raise ValueError('model_config.model.completion_params is required') - completion_params = AppModelConfigService.validate_model_completion_params( - cp=args['model_config']['model']['completion_params'], - model_name=app_model_config.model_dict["name"] + completion_params = ModelValidator.validate_model_completion_params( + cp=args['model_config']['model']['completion_params'] ) app_model_config_model = app_model_config.model_dict @@ -115,9 +115,8 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, # validate config model_config = AppModelConfigService.validate_configuration( tenant_id=app_model.tenant_id, - account=user, config=args['model_config'], - app_mode=app_model.mode + app_mode=AppMode.value_of(app_model.mode) ) app_model_config = AppModelConfig( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ae6e4c46d3a750..5a9234c70a4c26 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -96,7 +96,7 @@ def publish_workflow(self, app_model: App, if not draft_workflow: raise ValueError('No valid workflow found.') - # TODO check if the workflow is valid + # TODO check if the workflow is valid, basic check # create new workflow workflow = Workflow( From d741527ae4b6f7257c9ceb243f8c2190fa226632 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 22:16:36 +0800 Subject: [PATCH 041/160] lint --- api/controllers/console/app/model_config.py | 2 +- api/core/apps/config_validators/agent.py | 3 +-- api/core/apps/config_validators/dataset.py | 3 +-- api/core/apps/config_validators/external_data_tools.py | 3 +-- api/core/apps/config_validators/file_upload.py | 3 +-- api/core/apps/config_validators/model.py | 5 ++--- api/core/apps/config_validators/moderation.py | 3 +-- api/core/apps/config_validators/more_like_this.py | 3 +-- api/core/apps/config_validators/opening_statement.py | 3 +-- api/core/apps/config_validators/prompt.py | 3 +-- api/core/apps/config_validators/retriever_resource.py | 3 +-- api/core/apps/config_validators/speech_to_text.py | 3 +-- api/core/apps/config_validators/suggested_questions.py | 3 +-- api/core/apps/config_validators/text_to_speech.py | 3 +-- api/core/apps/config_validators/user_input_form.py | 3 +-- api/services/completion_service.py | 2 +- 16 files changed, 17 insertions(+), 31 deletions(-) diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 0ae9f5e5464080..d822f859bc35e8 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -14,7 +14,7 @@ from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required -from models.model import AppModelConfig, AppMode +from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py index 69f93380801b51..c6584d2903e246 100644 --- a/api/core/apps/config_validators/agent.py +++ b/api/core/apps/config_validators/agent.py @@ -1,5 +1,4 @@ import uuid -from typing import Tuple from core.agent.agent_executor 
import PlanningStrategy from core.apps.config_validators.dataset import DatasetValidator @@ -9,7 +8,7 @@ class AgentValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for agent feature diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py index 32db038c21a0d1..9846f9085c0c2a 100644 --- a/api/core/apps/config_validators/dataset.py +++ b/api/core/apps/config_validators/dataset.py @@ -1,5 +1,4 @@ import uuid -from typing import Tuple from core.agent.agent_executor import PlanningStrategy from models.model import AppMode @@ -8,7 +7,7 @@ class DatasetValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for dataset feature diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/apps/config_validators/external_data_tools.py index 5412366a897c3c..02ecc8d71598c5 100644 --- a/api/core/apps/config_validators/external_data_tools.py +++ b/api/core/apps/config_validators/external_data_tools.py @@ -1,11 +1,10 @@ -from typing import Tuple from core.external_data_tool.factory import ExternalDataToolFactory class ExternalDataToolsValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for external data fetch feature diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/apps/config_validators/file_upload.py index f9adbfdf7df3a1..419465bd5119ab 100644 --- a/api/core/apps/config_validators/file_upload.py +++ b/api/core/apps/config_validators/file_upload.py @@ -1,9 +1,8 @@ -from typing import Tuple class FileUploadValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature diff --git a/api/core/apps/config_validators/model.py b/api/core/apps/config_validators/model.py index 091eec4683a37f..1d86fbaf04e836 100644 --- a/api/core/apps/config_validators/model.py +++ b/api/core/apps/config_validators/model.py @@ -1,13 +1,12 @@ -from typing import Tuple -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.provider_manager import ProviderManager class ModelValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for model config diff --git a/api/core/apps/config_validators/moderation.py b/api/core/apps/config_validators/moderation.py index 1962f87aa9b181..4813385588905b 100644 --- a/api/core/apps/config_validators/moderation.py +++ b/api/core/apps/config_validators/moderation.py @@ -1,5 +1,4 @@ import logging -from typing import Tuple from core.moderation.factory import ModerationFactory @@ -8,7 
+7,7 @@ class ModerationValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id, config: dict) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = { "enabled": False diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/apps/config_validators/more_like_this.py index 60dc4a056255fa..1c1bac9de64431 100644 --- a/api/core/apps/config_validators/more_like_this.py +++ b/api/core/apps/config_validators/more_like_this.py @@ -1,9 +1,8 @@ -from typing import Tuple class MoreLikeThisValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for more like this feature diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/apps/config_validators/opening_statement.py index 3f69e0e9469430..f919230e0d1611 100644 --- a/api/core/apps/config_validators/opening_statement.py +++ b/api/core/apps/config_validators/opening_statement.py @@ -1,9 +1,8 @@ -from typing import Tuple class OpeningStatementValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for opening statement feature diff --git a/api/core/apps/config_validators/prompt.py b/api/core/apps/config_validators/prompt.py index 815706b10b9b3e..288a5234155b2c 100644 --- a/api/core/apps/config_validators/prompt.py +++ b/api/core/apps/config_validators/prompt.py @@ -1,4 +1,3 @@ -from typing import Tuple from core.entities.application_entities import PromptTemplateEntity from core.prompt.simple_prompt_transform import ModelMode @@ -7,7 +6,7 @@ class PromptValidator: @classmethod - def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ Validate pre_prompt and set defaults for prompt feature depending on the config['model'] diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/apps/config_validators/retriever_resource.py index a8bcd60abef5c9..32725c74328fbd 100644 --- a/api/core/apps/config_validators/retriever_resource.py +++ b/api/core/apps/config_validators/retriever_resource.py @@ -1,9 +1,8 @@ -from typing import Tuple class RetrieverResourceValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for retriever resource feature diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/apps/config_validators/speech_to_text.py index 577bef0e59f8fb..92a1b25ae69085 100644 --- a/api/core/apps/config_validators/speech_to_text.py +++ b/api/core/apps/config_validators/speech_to_text.py @@ -1,9 +1,8 @@ -from typing import Tuple class SpeechToTextValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for speech to text feature diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/apps/config_validators/suggested_questions.py index 938b66bb6ecb38..9161b316781489 
100644 --- a/api/core/apps/config_validators/suggested_questions.py +++ b/api/core/apps/config_validators/suggested_questions.py @@ -1,9 +1,8 @@ -from typing import Tuple class SuggestedQuestionsValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for suggested questions feature diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/apps/config_validators/text_to_speech.py index efe34a8a3e902d..182a912d52c692 100644 --- a/api/core/apps/config_validators/text_to_speech.py +++ b/api/core/apps/config_validators/text_to_speech.py @@ -1,9 +1,8 @@ -from typing import Tuple class TextToSpeechValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for text to speech feature diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/apps/config_validators/user_input_form.py index 7116c55afc636c..249d6745ae7bf3 100644 --- a/api/core/apps/config_validators/user_input_form.py +++ b/api/core/apps/config_validators/user_input_form.py @@ -1,10 +1,9 @@ import re -from typing import Tuple class UserInputFormValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for user input form diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 6dd729694ba2cf..9acd62b997f044 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -9,7 +9,7 @@ from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message, AppMode +from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message from services.app_model_config_service import AppModelConfigService from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError From 3badc4423a6fb91642b2263c68cc4442d06a3787 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 12:22:30 +0800 Subject: [PATCH 042/160] fix: wrong default model parameters when creating app --- api/constants/model_template.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index ca0b7549897bc0..61aab64d8a298e 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -23,13 +23,7 @@ "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } + "completion_params": {} } } }, @@ -46,13 +40,7 @@ "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } + "completion_params": {} } } }, @@ -69,16 +57,8 @@ "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - 
"frequency_penalty": 0 - } + "completion_params": {} } } - }, + } } - - From 896c20021156bd3877b844f122375e01c92ba4b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 13:24:26 +0800 Subject: [PATCH 043/160] fix import problem --- api/core/apps/config_validators/agent.py | 2 +- api/core/apps/config_validators/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py index c6584d2903e246..b445aedbf868af 100644 --- a/api/core/apps/config_validators/agent.py +++ b/api/core/apps/config_validators/agent.py @@ -1,7 +1,7 @@ import uuid -from core.agent.agent_executor import PlanningStrategy from core.apps.config_validators.dataset import DatasetValidator +from core.entities.agent_entities import PlanningStrategy OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py index 9846f9085c0c2a..fb5b64832073ae 100644 --- a/api/core/apps/config_validators/dataset.py +++ b/api/core/apps/config_validators/dataset.py @@ -1,6 +1,6 @@ import uuid -from core.agent.agent_executor import PlanningStrategy +from core.entities.agent_entities import PlanningStrategy from models.model import AppMode from services.dataset_service import DatasetService From 799db69e4f334a20cbbfad540b518bffc4b698d9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 17:33:52 +0800 Subject: [PATCH 044/160] refactor app --- api/controllers/console/app/completion.py | 6 +- api/controllers/console/app/generator.py | 2 +- api/controllers/console/explore/completion.py | 6 +- api/controllers/service_api/app/completion.py | 6 +- api/controllers/web/completion.py | 6 +- api/core/{app_runner => agent}/__init__.py | 0 .../base_agent_runner.py} | 8 +- .../cot_agent_runner.py} | 6 +- .../fc_agent_runner.py} | 6 +- api/core/{apps => app}/__init__.py | 0 .../advanced_chat}/__init__.py | 0 .../advanced_chat/config_validator.py} | 14 +- .../agent_chat}/__init__.py | 0 .../agent_chat/app_runner.py} | 19 +- api/core/app/agent_chat/config_validator.py | 162 +++++++ api/core/app/app_manager.py | 382 +++++++++++++++ .../app_orchestration_config_converter.py} | 436 ++---------------- .../app_queue_manager.py} | 6 +- .../app_runner.py => app/base_app_runner.py} | 26 +- api/core/{features => app/chat}/__init__.py | 0 .../chat/app_runner.py} | 16 +- .../chat/config_validator.py} | 26 +- .../completion}/__init__.py | 0 api/core/app/completion/app_runner.py | 266 +++++++++++ .../completion/config_validator.py} | 20 +- .../agent => app/features}/__init__.py | 0 .../features/annotation_reply}/__init__.py | 0 .../annotation_reply}/annotation_reply.py | 0 .../features/hosting_moderation/__init__.py | 0 .../hosting_moderation}/hosting_moderation.py | 0 .../generate_task_pipeline.py | 12 +- api/core/app/validators/__init__.py | 0 .../validators/dataset_retrieval.py} | 0 .../validators/external_data_fetch.py} | 2 +- .../validators}/file_upload.py | 0 .../validators/model_validator.py} | 0 .../validators}/moderation.py | 0 .../validators}/more_like_this.py | 0 .../validators}/opening_statement.py | 0 .../validators}/prompt.py | 0 .../validators}/retriever_resource.py | 0 .../validators}/speech_to_text.py | 0 .../validators}/suggested_questions.py | 0 .../validators}/text_to_speech.py | 0 .../validators}/user_input_form.py | 0 api/core/app/workflow/__init__.py | 0 .../workflow/config_validator.py} | 6 +- 
.../app_config_validators/agent_chat_app.py | 82 ---- api/core/apps/config_validators/agent.py | 81 ---- .../agent_loop_gather_callback_handler.py | 4 +- .../index_tool_callback_handler.py | 4 +- .../external_data_fetch.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/llm_generator/__init__.py | 0 .../llm_generator.py | 8 +- .../llm_generator/output_parser/__init__.py | 0 .../output_parser/rule_config_generator.py | 2 +- .../suggested_questions_after_answer.py | 2 +- api/core/{prompt => llm_generator}/prompts.py | 0 .../input_moderation.py} | 2 +- .../output_moderation.py} | 4 +- api/core/prompt/__init__.py | 0 api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/prompt/prompt_templates/__init__.py | 0 .../advanced_prompt_templates.py | 0 .../baichuan_chat.json | 0 .../baichuan_completion.json | 0 .../common_chat.json | 0 .../common_completion.json | 0 api/core/prompt/simple_prompt_transform.py | 4 +- api/core/prompt/utils/__init__.py | 0 .../prompt_template_parser.py} | 0 .../processor/qa_index_processor.py | 2 +- api/core/rag/retrieval/__init__.py | 0 api/core/rag/retrieval/agent/__init__.py | 0 .../retrieval}/agent/agent_llm_callback.py | 0 .../retrieval}/agent/fake_llm.py | 0 .../retrieval}/agent/llm_chain.py | 4 +- .../agent/multi_dataset_router_agent.py | 2 +- .../retrieval/agent/output_parser/__init__.py | 0 .../agent/output_parser/structured_chat.py | 0 .../structed_multi_dataset_router_agent.py | 2 +- .../agent_based_dataset_executor.py | 8 +- .../retrieval}/dataset_retrieval.py | 4 +- api/core/tools/tool/dataset_retriever_tool.py | 4 +- ...rsation_name_when_first_message_created.py | 2 +- api/models/model.py | 18 +- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 10 +- api/services/completion_service.py | 8 +- api/services/conversation_service.py | 2 +- api/services/message_service.py | 2 +- api/services/workflow/workflow_converter.py | 4 +- .../prompt/test_advanced_prompt_transform.py | 2 +- 94 files changed, 992 insertions(+), 722 deletions(-) rename api/core/{app_runner => agent}/__init__.py (100%) rename api/core/{features/assistant_base_runner.py => agent/base_agent_runner.py} (99%) rename api/core/{features/assistant_cot_runner.py => agent/cot_agent_runner.py} (99%) rename api/core/{features/assistant_fc_runner.py => agent/fc_agent_runner.py} (98%) rename api/core/{apps => app}/__init__.py (100%) rename api/core/{apps/app_config_validators => app/advanced_chat}/__init__.py (100%) rename api/core/{apps/app_config_validators/advanced_chat_app.py => app/advanced_chat/config_validator.py} (77%) rename api/core/{apps/config_validators => app/agent_chat}/__init__.py (100%) rename api/core/{app_runner/assistant_app_runner.py => app/agent_chat/app_runner.py} (95%) create mode 100644 api/core/app/agent_chat/config_validator.py create mode 100644 api/core/app/app_manager.py rename api/core/{application_manager.py => app/app_orchestration_config_converter.py} (52%) rename api/core/{application_queue_manager.py => app/app_queue_manager.py} (97%) rename api/core/{app_runner/app_runner.py => app/base_app_runner.py} (94%) rename api/core/{features => app/chat}/__init__.py (100%) rename api/core/{app_runner/basic_app_runner.py => app/chat/app_runner.py} (95%) rename api/core/{apps/app_config_validators/chat_app.py => app/chat/config_validator.py} (75%) rename api/core/{features/dataset_retrieval => app/completion}/__init__.py (100%) create mode 100644 api/core/app/completion/app_runner.py rename 
api/core/{apps/app_config_validators/completion_app.py => app/completion/config_validator.py} (76%) rename api/core/{features/dataset_retrieval/agent => app/features}/__init__.py (100%) rename api/core/{features/dataset_retrieval/agent/output_parser => app/features/annotation_reply}/__init__.py (100%) rename api/core/{features => app/features/annotation_reply}/annotation_reply.py (100%) create mode 100644 api/core/app/features/hosting_moderation/__init__.py rename api/core/{features => app/features/hosting_moderation}/hosting_moderation.py (100%) rename api/core/{app_runner => app}/generate_task_pipeline.py (98%) create mode 100644 api/core/app/validators/__init__.py rename api/core/{apps/config_validators/dataset.py => app/validators/dataset_retrieval.py} (100%) rename api/core/{apps/config_validators/external_data_tools.py => app/validators/external_data_fetch.py} (97%) rename api/core/{apps/config_validators => app/validators}/file_upload.py (100%) rename api/core/{apps/config_validators/model.py => app/validators/model_validator.py} (100%) rename api/core/{apps/config_validators => app/validators}/moderation.py (100%) rename api/core/{apps/config_validators => app/validators}/more_like_this.py (100%) rename api/core/{apps/config_validators => app/validators}/opening_statement.py (100%) rename api/core/{apps/config_validators => app/validators}/prompt.py (100%) rename api/core/{apps/config_validators => app/validators}/retriever_resource.py (100%) rename api/core/{apps/config_validators => app/validators}/speech_to_text.py (100%) rename api/core/{apps/config_validators => app/validators}/suggested_questions.py (100%) rename api/core/{apps/config_validators => app/validators}/text_to_speech.py (100%) rename api/core/{apps/config_validators => app/validators}/user_input_form.py (100%) create mode 100644 api/core/app/workflow/__init__.py rename api/core/{apps/app_config_validators/workflow_app.py => app/workflow/config_validator.py} (83%) delete mode 100644 api/core/apps/app_config_validators/agent_chat_app.py delete mode 100644 api/core/apps/config_validators/agent.py rename api/core/{features => external_data_tool}/external_data_fetch.py (98%) create mode 100644 api/core/llm_generator/__init__.py rename api/core/{generator => llm_generator}/llm_generator.py (93%) create mode 100644 api/core/llm_generator/output_parser/__init__.py rename api/core/{prompt => llm_generator}/output_parser/rule_config_generator.py (94%) rename api/core/{prompt => llm_generator}/output_parser/suggested_questions_after_answer.py (87%) rename api/core/{prompt => llm_generator}/prompts.py (100%) rename api/core/{features/moderation.py => moderation/input_moderation.py} (98%) rename api/core/{app_runner/moderation_handler.py => moderation/output_moderation.py} (97%) create mode 100644 api/core/prompt/__init__.py create mode 100644 api/core/prompt/prompt_templates/__init__.py rename api/core/prompt/{ => prompt_templates}/advanced_prompt_templates.py (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_completion.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_completion.json (100%) create mode 100644 api/core/prompt/utils/__init__.py rename api/core/prompt/{prompt_template.py => utils/prompt_template_parser.py} (100%) create mode 100644 api/core/rag/retrieval/__init__.py create mode 
100644 api/core/rag/retrieval/agent/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/agent_llm_callback.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/fake_llm.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/llm_chain.py (91%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/multi_dataset_router_agent.py (98%) create mode 100644 api/core/rag/retrieval/agent/output_parser/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/output_parser/structured_chat.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/structed_multi_dataset_router_agent.py (99%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent_based_dataset_executor.py (92%) rename api/core/{features/dataset_retrieval => rag/retrieval}/dataset_retrieval.py (98%) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index e62475308faf04..0632c0439b29b5 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 @@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 3ec932b5f11ead..ee02fc18465c5d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -11,7 +11,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 6406d5b3b05f6e..22ea4bbac242ee 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,7 +21,7 @@ ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import 
AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -90,7 +90,7 @@ def post(self, installed_app, task_id): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 @@ -154,7 +154,7 @@ def post(self, installed_app, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c6cfb24378c482..fd4ce831b37081 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -19,7 +19,7 @@ ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -85,7 +85,7 @@ def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'completion': raise AppUnavailableError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 @@ -147,7 +147,7 @@ def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 61d4f8c36232ba..fd94ec7646bd74 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,7 +20,7 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -84,7 +84,7 @@ def post(self, app_model, end_user, task_id): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 @@ -144,7 +144,7 @@ def post(self, app_model, end_user, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 
200 diff --git a/api/core/app_runner/__init__.py b/api/core/agent/__init__.py similarity index 100% rename from api/core/app_runner/__init__.py rename to api/core/agent/__init__.py diff --git a/api/core/features/assistant_base_runner.py b/api/core/agent/base_agent_runner.py similarity index 99% rename from api/core/features/assistant_base_runner.py rename to api/core/agent/base_agent_runner.py index 1d9541070f881f..0658124d1422ff 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,8 +5,8 @@ from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( @@ -48,13 +48,13 @@ logger = logging.getLogger(__name__) -class BaseAssistantApplicationRunner(AppRunner): +class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, application_generate_entity: ApplicationGenerateEntity, app_orchestration_config: AppOrchestrationConfigEntity, model_config: ModelConfigEntity, config: AgentEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, diff --git a/api/core/features/assistant_cot_runner.py b/api/core/agent/cot_agent_runner.py similarity index 99% rename from api/core/features/assistant_cot_runner.py rename to api/core/agent/cot_agent_runner.py index 3762ddcf62e7c5..152e4457955fae 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -3,9 +3,9 @@ from collections.abc import Generator from typing import Literal, Union -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -262,7 +262,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_call_args = json.loads(tool_call_args) except json.JSONDecodeError: pass - + tool_response = tool_instance.invoke( user_id=self.user_id, tool_parameters=tool_call_args diff --git a/api/core/features/assistant_fc_runner.py b/api/core/agent/fc_agent_runner.py similarity index 98% rename from api/core/features/assistant_fc_runner.py rename to api/core/agent/fc_agent_runner.py index 391e040c53d32b..0cf0d3762cea43 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -3,8 +3,8 @@ from collections.abc import Generator from typing import Any, Union -from core.application_queue_manager import PublishFrom -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.app.app_queue_manager import PublishFrom +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, 
LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): +class FunctionCallAgentRunner(BaseAgentRunner): def run(self, conversation: Conversation, message: Message, query: str, diff --git a/api/core/apps/__init__.py b/api/core/app/__init__.py similarity index 100% rename from api/core/apps/__init__.py rename to api/core/app/__init__.py diff --git a/api/core/apps/app_config_validators/__init__.py b/api/core/app/advanced_chat/__init__.py similarity index 100% rename from api/core/apps/app_config_validators/__init__.py rename to api/core/app/advanced_chat/__init__.py diff --git a/api/core/apps/app_config_validators/advanced_chat_app.py b/api/core/app/advanced_chat/config_validator.py similarity index 77% rename from api/core/apps/app_config_validators/advanced_chat_app.py rename to api/core/app/advanced_chat/config_validator.py index dc7664b844f68c..39c00c028ef2ee 100644 --- a/api/core/apps/app_config_validators/advanced_chat_app.py +++ b/api/core/app/advanced_chat/config_validator.py @@ -1,10 +1,10 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class AdvancedChatAppConfigValidator: diff --git a/api/core/apps/config_validators/__init__.py b/api/core/app/agent_chat/__init__.py similarity index 100% rename from api/core/apps/config_validators/__init__.py rename to api/core/app/agent_chat/__init__.py diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app/agent_chat/app_runner.py similarity index 95% rename from api/core/app_runner/assistant_app_runner.py rename to api/core/app/agent_chat/app_runner.py index 655a5a1c7c811d..b046e935a52a8a 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app/agent_chat/app_runner.py @@ -1,11 +1,11 @@ import logging from typing import cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity -from core.features.assistant_cot_runner import AssistantCotApplicationRunner -from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.fc_agent_runner 
import FunctionCallAgentRunner from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -19,12 +19,13 @@ logger = logging.getLogger(__name__) -class AssistantApplicationRunner(AppRunner): + +class AgentChatAppRunner(AppRunner): """ - Assistant Application Runner + Agent Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -201,7 +202,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = AssistantCotApplicationRunner( + assistant_cot_runner = CotAgentRunner( tenant_id=application_generate_entity.tenant_id, application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, @@ -223,7 +224,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = AssistantFunctionCallApplicationRunner( + assistant_fc_runner = FunctionCallAgentRunner( tenant_id=application_generate_entity.tenant_id, application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/agent_chat/config_validator.py new file mode 100644 index 00000000000000..6596b19f9945e7 --- /dev/null +++ b/api/core/app/agent_chat/config_validator.py @@ -0,0 +1,162 @@ +import uuid + +from core.entities.agent_entities import PlanningStrategy +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = 
ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + + @classmethod + def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate agent_mode and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in + list(PlanningStrategy.__members__.values())]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled 
in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py new file mode 100644 index 00000000000000..0819ed864baab0 --- /dev/null +++ b/api/core/app/app_manager.py @@ -0,0 +1,382 @@ +import json +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Optional, Union, cast + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter +from core.app.agent_chat.app_runner import AgentChatAppRunner +from core.app.chat.app_runner import ChatAppRunner +from core.app.generate_task_pipeline import GenerateTaskPipeline +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.entities.application_entities import ( + ApplicationGenerateEntity, + InvokeFrom, +) +from core.file.file_obj import FileObj +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message, MessageFile + +logger = logging.getLogger(__name__) + + +class AppManager: + """ + This class is responsible for managing application + """ + + def generate(self, tenant_id: str, + app_id: str, + app_model_config_id: str, + app_model_config_dict: dict, + app_model_config_override: bool, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + inputs: dict[str, str], + query: Optional[str] = None, + files: Optional[list[FileObj]] = None, + conversation: Optional[Conversation] = None, + stream: bool = False, + extras: Optional[dict[str, Any]] = None) \ + -> Union[dict, Generator]: + """ + Generate App response. 
+ + :param tenant_id: workspace ID + :param app_id: app ID + :param app_model_config_id: app model config id + :param app_model_config_dict: app model config dict + :param app_model_config_override: app model config override + :param user: account or end user + :param invoke_from: invoke from source + :param inputs: inputs + :param query: query + :param files: file obj list + :param conversation: conversation + :param stream: is stream + :param extras: extras + """ + # init task id + task_id = str(uuid.uuid4()) + + # init application generate entity + application_generate_entity = ApplicationGenerateEntity( + task_id=task_id, + tenant_id=tenant_id, + app_id=app_id, + app_model_config_id=app_model_config_id, + app_model_config_dict=app_model_config_dict, + app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict( + tenant_id=tenant_id, + app_model_config_dict=app_model_config_dict + ), + app_model_config_override=app_model_config_override, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else inputs, + query=query.replace('\x00', '') if query else None, + files=files if files else [], + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + if not stream and application_generate_entity.app_orchestration_config_entity.agent: + raise ValueError("Agent app is not supported in blocking mode.") + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + if application_generate_entity.app_orchestration_config_entity.agent: + # agent app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + else: + # basic app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = GenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() + + def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ + -> tuple[Conversation, Message]: + """ + Initialize generate records + :param application_generate_entity: application generate entity + :return: + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + + model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_schema = model_type_instance.get_model_schema( + model=app_orchestration_config_entity.model_config.model, + credentials=app_orchestration_config_entity.model_config.credentials + ) + + app_record = (db.session.query(App) + .filter(App.id == application_generate_entity.app_id).first()) + + app_mode = app_record.mode + + # get from source + end_user_id = None + account_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + from_source = 'api' + end_user_id = 
application_generate_entity.user_id + else: + from_source = 'console' + account_id = application_generate_entity.user_id + + override_model_configs = None + if application_generate_entity.app_model_config_override: + override_model_configs = application_generate_entity.app_model_config_dict + + introduction = '' + if app_mode == 'chat': + # get conversation introduction + introduction = self._get_conversation_introduction(application_generate_entity) + + if not application_generate_entity.conversation_id: + conversation = Conversation( + app_id=app_record.id, + app_model_config_id=application_generate_entity.app_model_config_id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_mode, + name='New conversation', + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status='normal', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.commit() + else: + conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == application_generate_entity.conversation_id, + Conversation.app_id == app_record.id + ).first() + ) + + currency = model_schema.pricing.currency if model_schema.pricing else 'USD' + + message = Message( + app_id=app_record.id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=conversation.id, + inputs=application_generate_entity.inputs, + query=application_generate_entity.query or "", + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency=currency, + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + agent_based=app_orchestration_config_entity.agent is not None + ) + + db.session.add(message) + db.session.commit() + + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + belongs_to='user', + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if account_id else 'end_user'), + created_by=account_id or end_user_id, + ) + db.session.add(message_file) + db.session.commit() + + return conversation, message + + def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: + """ + Get conversation introduction + :param application_generate_entity: application generate entity + :return: conversation introduction + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + introduction = app_orchestration_config_entity.opening_statement + + if introduction: + try: + inputs = application_generate_entity.inputs + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + introduction = prompt_template.format(prompt_inputs) + except KeyError: + pass + + return introduction + + def _get_conversation(self, 
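_get_conversation_introduction formats the opening statement with only the variables actually present in inputs, swallowing KeyError rather than failing the request. A rough stand-in for PromptTemplateParser, assuming the {{variable}} placeholder syntax used by Dify prompts (an assumption of this sketch, not confirmed by the hunk above):

```python
import re

class TinyTemplateParser:
    # Minimal stand-in for core.prompt's PromptTemplateParser; {{variable}}
    # syntax is assumed here.
    def __init__(self, template: str) -> None:
        self.template = template
        self.variable_keys = re.findall(r'\{\{(\w+)\}\}', template)

    def format(self, inputs: dict) -> str:
        return re.sub(r'\{\{(\w+)\}\}',
                      lambda m: str(inputs.get(m.group(1), '')),
                      self.template)

parser = TinyTemplateParser("Hi {{name}}, ask me about {{topic}}.")
inputs = {'name': 'Ada', 'topic': 'math', 'unused': 'x'}
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
print(parser.format(prompt_inputs))  # Hi Ada, ask me about math.
```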
conversation_id: str) -> Conversation: + """ + Get conversation by conversation id + :param conversation_id: conversation id + :return: conversation + """ + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + return conversation + + def _get_message(self, message_id: str) -> Message: + """ + Get message by message id + :param message_id: message id + :return: message + """ + message = ( + db.session.query(Message) + .filter(Message.id == message_id) + .first() + ) + + return message diff --git a/api/core/application_manager.py b/api/core/app/app_orchestration_config_converter.py similarity index 52% rename from api/core/application_manager.py rename to api/core/app/app_orchestration_config_converter.py index ea0c85427d4811..ddf49949a3bfb5 100644 --- a/api/core/application_manager.py +++ b/api/core/app/app_orchestration_config_converter.py @@ -1,241 +1,21 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast - -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app_runner.assistant_app_runner import AssistantApplicationRunner -from core.app_runner.basic_app_runner import BasicApplicationRunner -from core.app_runner.generate_task_pipeline import GenerateTaskPipeline -from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, - VariableEntity, -) +from typing import cast + +from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \ + TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \ + ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \ + AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileObj +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager from core.tools.prompt.template import REACT_PROMPT_TEMPLATES -from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) - - -class ApplicationManager: - """ - This class is responsible for managing application - """ - - def generate(self, tenant_id: str, - app_id: str, - 
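With convert_from_app_model_config_dict moved onto AppOrchestrationConfigConverter as a classmethod, call sites no longer need an ApplicationManager instance. A toy version showing the shape of the new call; the real method returns an AppOrchestrationConfigEntity, a dict stands in here:

```python
class AppOrchestrationConfigConverter:
    # The conversion needs no manager state, so it becomes a classmethod on a
    # dedicated converter; callers drop the ApplicationManager instance.
    @classmethod
    def convert_from_app_model_config_dict(cls, tenant_id: str,
                                           app_model_config_dict: dict,
                                           skip_check: bool = False) -> dict:
        return {'tenant_id': tenant_id, **app_model_config_dict}

entity = AppOrchestrationConfigConverter.convert_from_app_model_config_dict(
    tenant_id='t-1', app_model_config_dict={'model': {'mode': 'chat'}})
```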
app_model_config_id: str, - app_model_config_dict: dict, - app_model_config_override: bool, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. - - :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # init application generate entity - application_generate_entity = ApplicationGenerateEntity( - task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, - app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self.convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_orchestration_config_entity.agent: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = ApplicationQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. 
- :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_orchestration_config_entity.agent: - # agent app - runner = AssistantApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - else: - # basic app - runner = BasicApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() - - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. 
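Two details of the worker are easy to miss: every exception class is routed into the queue as a published error rather than escaping the thread, and the cleanup in the new version switches from db.session.close() to db.session.remove(), which with a scoped session also evicts the thread-local registry entry so a reused thread starts with a fresh session. A minimal sketch of both, using SQLAlchemy's scoped_session directly:

```python
import logging
import queue

from sqlalchemy import create_engine, text
from sqlalchemy.orm import scoped_session, sessionmaker

logger = logging.getLogger(__name__)
Session = scoped_session(sessionmaker(bind=create_engine('sqlite://')))

class TaskStopped(Exception):
    """Consumer cancelled the task; silence, not an error."""

def run_worker(q: queue.Queue, job) -> None:
    try:
        Session().execute(text('SELECT 1'))  # any thread-local DB work
        job()
    except TaskStopped:
        pass                                 # cancellation is expected
    except ValueError as e:
        q.put(('error', e))                  # surfaced to the consumer thread
    except Exception as e:
        logger.exception("Unknown error when generating")
        q.put(('error', e))
    finally:
        # remove() closes *and* discards the thread-local session; close()
        # alone would leave a stale session bound to this worker thread.
        Session.remove()
```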
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - def convert_from_app_model_config_dict(self, tenant_id: str, +class AppOrchestrationConfigConverter: + @classmethod + def convert_from_app_model_config_dict(cls, tenant_id: str, app_model_config_dict: dict, skip_check: bool = False) \ -> AppOrchestrationConfigEntity: @@ -394,7 +174,7 @@ def convert_from_app_model_config_dict(self, tenant_id: str, ) properties['variables'] = [] - + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] @@ -444,7 +224,7 @@ def convert_from_app_model_config_dict(self, tenant_id: str, show_retrieve_source = True properties['show_retrieve_source'] = show_retrieve_source - + dataset_ids = [] if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { @@ -452,26 +232,23 @@ def convert_from_app_model_config_dict(self, tenant_id: str, 'datasets': [] }) - for dataset in datasets.get('datasets', []): keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != 'dataset': continue dataset = dataset['dataset'] - + if 'enabled' not in dataset or not dataset['enabled']: continue - + dataset_id = dataset.get('id', None) if dataset_id: dataset_ids.append(dataset_id) - else: - datasets = {'strategy': 'router', 'datasets': []} if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ and 'enabled' in copy_app_model_config_dict['agent_mode'] \ and copy_app_model_config_dict['agent_mode']['enabled']: - + agent_dict = copy_app_model_config_dict.get('agent_mode', {}) agent_strategy = agent_dict.get('strategy', 'cot') @@ -515,7 +292,7 @@ def convert_from_app_model_config_dict(self, tenant_id: str, dataset_id = tool_item['id'] dataset_ids.append(dataset_id) - + if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: agent_prompt = agent_dict.get('prompt', None) or {} @@ -523,13 +300,18 @@ def convert_from_app_model_config_dict(self, tenant_id: str, model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') if model_mode == 'completion': agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['completion'][ + 'agent_scratchpad']), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - 
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), ) properties['agent'] = AgentEntity( @@ -551,7 +333,7 @@ def convert_from_app_model_config_dict(self, tenant_id: str, dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] ) ) @@ -624,169 +406,3 @@ def convert_from_app_model_config_dict(self, tenant_id: str, ) return AppOrchestrationConfigEntity(**properties) - - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials - ) - - app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - db.session.refresh(conversation) - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - 
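The agent-prompt defaulting above keys off the model mode: user-supplied first_prompt/next_iteration win, otherwise the English ReAct defaults for 'completion' or chat mode apply. Condensed below, with a stub table; the real REACT_PROMPT_TEMPLATES lives in core.tools.prompt.template and is considerably larger:

```python
# Illustrative stub; real templates are full ReAct prompts.
REACT_PROMPT_TEMPLATES = {
    'english': {
        'completion': {'prompt': 'Answer: ...', 'agent_scratchpad': 'Thought: ...'},
        'chat': {'prompt': 'You are ...', 'agent_scratchpad': 'Observation: ...'},
    }
}

def pick_agent_prompts(agent_prompt: dict, model_mode: str) -> tuple[str, str]:
    # Same fallback rule as the converter: overrides win, otherwise the ReAct
    # defaults for the model's mode ('completion' vs anything else = chat).
    defaults = REACT_PROMPT_TEMPLATES['english'][
        'completion' if model_mode == 'completion' else 'chat']
    return (agent_prompt.get('first_prompt', defaults['prompt']),
            agent_prompt.get('next_iteration', defaults['agent_scratchpad']))

print(pick_agent_prompts({}, 'chat'))
```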
model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None - ) - - db.session.add(message) - db.session.commit() - db.session.refresh(message) - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) - - return message diff --git a/api/core/application_queue_manager.py b/api/core/app/app_queue_manager.py similarity index 97% rename from api/core/application_queue_manager.py rename to api/core/app/app_queue_manager.py index 9590a1e7266f2f..c09cae3245ce1b 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -32,7 +32,7 @@ class PublishFrom(Enum): TASK_PIPELINE = 2 -class ApplicationQueueManager: +class AppQueueManager: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, @@ -50,7 +50,7 @@ def __init__(self, task_id: str, self._message_id = str(message_id) user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, 
f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -239,7 +239,7 @@ def _is_stopped(self) -> bool: Check if task is stopped :return: """ - stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) + stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) result = redis_client.get(stopped_cache_key) if result is not None: return True diff --git a/api/core/app_runner/app_runner.py b/api/core/app/base_app_runner.py similarity index 94% rename from api/core/app_runner/app_runner.py rename to api/core/app/base_app_runner.py index 95f2f568dcfad2..788e3f91a3fcaf 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app/base_app_runner.py @@ -2,7 +2,7 @@ from collections.abc import Generator from typing import Optional, Union, cast -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ( ApplicationGenerateEntity, AppOrchestrationConfigEntity, @@ -11,10 +11,10 @@ ModelConfigEntity, PromptTemplateEntity, ) -from core.features.annotation_reply import AnnotationReplyFeature -from core.features.external_data_fetch import ExternalDataFetchFeature -from core.features.hosting_moderation import HostingModerationFeature -from core.features.moderation import ModerationFeature +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature +from core.moderation.input_moderation import InputModeration from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -169,7 +169,7 @@ def organize_prompt_messages(self, app_record: App, return prompt_messages, stop - def direct_output(self, queue_manager: ApplicationQueueManager, + def direct_output(self, queue_manager: AppQueueManager, app_orchestration_config: AppOrchestrationConfigEntity, prompt_messages: list, text: str, @@ -210,7 +210,7 @@ def direct_output(self, queue_manager: ApplicationQueueManager, ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, stream: bool, agent: bool = False) -> None: """ @@ -234,7 +234,7 @@ def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], ) def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result direct @@ -248,7 +248,7 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, ) def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result @@ -306,7 +306,7 @@ def moderation_for_inputs(self, app_id: str, :param query: query :return: """ - moderation_feature = ModerationFeature() + moderation_feature = InputModeration() return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, @@ -316,7 +316,7 @@ def moderation_for_inputs(self, app_id: str, ) def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: 
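Task ownership and cancellation both ride on short-lived Redis keys, as the setex call and the _is_stopped check above show: the constructor records who owns a task for 30 minutes, and stopping a task just plants a key that workers poll for. A sketch with hypothetical key names; the real key formats are private to AppQueueManager:

```python
import redis

r = redis.Redis()  # assumes a local Redis; stands in for Dify's redis_client

def task_belong_key(task_id: str) -> str:
    # Hypothetical key layout for illustration only.
    return f"generate_task_belong:{task_id}"

def register_task(task_id: str, user_prefix: str, user_id: str) -> None:
    # Ownership record expires after 30 minutes (1800 s), matching the
    # setex call in AppQueueManager.__init__.
    r.setex(task_belong_key(task_id), 1800, f"{user_prefix}-{user_id}")

def is_stopped(task_id: str) -> bool:
    # A stop request plants a key; workers merely probe for its existence.
    return r.get(f"generate_task_stopped:{task_id}") is not None
```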
AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -358,7 +358,7 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, :param query: the query :return: the filled inputs """ - external_data_fetch_feature = ExternalDataFetchFeature() + external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( tenant_id=tenant_id, app_id=app_id, @@ -388,4 +388,4 @@ def query_app_annotations_to_reply(self, app_record: App, query=query, user_id=user_id, invoke_from=invoke_from - ) \ No newline at end of file + ) diff --git a/api/core/features/__init__.py b/api/core/app/chat/__init__.py similarity index 100% rename from api/core/features/__init__.py rename to api/core/app/chat/__init__.py diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app/chat/app_runner.py similarity index 95% rename from api/core/app_runner/basic_app_runner.py rename to api/core/app/chat/app_runner.py index 0e0fe6e3bfa099..a1613e37a27c5d 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,8 +1,8 @@ import logging from typing import Optional -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, @@ -10,7 +10,7 @@ InvokeFrom, ModelConfigEntity, ) -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException @@ -20,13 +20,13 @@ logger = logging.getLogger(__name__) -class BasicApplicationRunner(AppRunner): +class ChatAppRunner(AppRunner): """ - Basic Application Runner + Chat Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -215,7 +215,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, def retrieve_dataset_context(self, tenant_id: str, app_record: App, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, model_config: ModelConfigEntity, dataset_config: DatasetEntity, show_retrieve_source: bool, @@ -254,7 +254,7 @@ def retrieve_dataset_context(self, tenant_id: str, and dataset_config.retrieve_config.query_variable): query = inputs.get(dataset_config.retrieve_config.query_variable, "") - dataset_retrieval = DatasetRetrievalFeature() + dataset_retrieval = DatasetRetrieval() return dataset_retrieval.retrieve( tenant_id=tenant_id, model_config=model_config, diff --git a/api/core/apps/app_config_validators/chat_app.py b/api/core/app/chat/config_validator.py similarity index 75% rename from api/core/apps/app_config_validators/chat_app.py rename to api/core/app/chat/config_validator.py index 83c792e610c386..adb8408e285013 100644 --- a/api/core/apps/app_config_validators/chat_app.py +++ b/api/core/app/chat/config_validator.py @@ -1,15 +1,15 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import 
ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -35,7 +35,7 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/__init__.py b/api/core/app/completion/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/__init__.py rename to api/core/app/completion/__init__.py diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py new file mode 100644 index 00000000000000..34c6a5156f6eee --- /dev/null +++ b/api/core/app/completion/app_runner.py @@ -0,0 +1,266 @@ +import logging +from typing import Optional + +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import ( + ApplicationGenerateEntity, + DatasetEntity, + InvokeFrom, + ModelConfigEntity, +) +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.moderation.base import ModerationException +from extensions.ext_database import db +from models.model import App, AppMode, Conversation, Message + +logger = logging.getLogger(__name__) + + +class CompletionAppRunner(AppRunner): + """ + Completion Application Runner + """ + + def run(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: 
AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + if not app_record: + raise ValueError("App not found") + + app_orchestration_config = application_generate_entity.app_orchestration_config_entity + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. + # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + memory = None + if application_generate_entity.conversation_id: + # get memory of conversation (read-only) + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=application_generate_entity.tenant_id, + app_orchestration_config_entity=app_orchestration_config, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_orchestration_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + 
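Moderation and annotation replies share one short-circuit shape in run(): emit the canned text via direct_output and return before the model is ever invoked. Reduced to its essence, with illustrative names:

```python
from typing import Callable, Optional

def answer(query: str,
           annotations: dict[str, str],
           moderate: Callable[[str], Optional[str]],
           call_model: Callable[[str], str]) -> str:
    # Same control flow as run(): a moderation verdict or a matching stored
    # annotation is returned directly and the LLM call is skipped entirely.
    verdict = moderate(query)
    if verdict is not None:
        return verdict
    if query in annotations:
        return annotations[query]
    return call_model(query)

print(answer('hi', {'hi': 'Hello! (canned reply)'},
             lambda q: None, lambda q: f'LLM({q})'))
```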
external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + # get context from datasets + context = None + if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + context = self.retrieve_dataset_context( + tenant_id=app_record.tenant_id, + app_record=app_record, + queue_manager=queue_manager, + model_config=app_orchestration_config.model_config, + show_retrieve_source=app_orchestration_config.show_retrieve_source, + dataset_config=app_orchestration_config.dataset, + message=message, + inputs=inputs, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + memory=memory + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context, + memory=memory + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recale_llm_max_tokens( + model_config=app_orchestration_config.model_config, + prompt_messages=prompt_messages + ) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_orchestration_config.model_config.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke result + self._handle_invoke_result( + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream + ) + + def retrieve_dataset_context(self, tenant_id: str, + app_record: App, + queue_manager: AppQueueManager, + model_config: ModelConfigEntity, + dataset_config: DatasetEntity, + show_retrieve_source: bool, + message: Message, + inputs: dict, + query: str, + user_id: str, + invoke_from: InvokeFrom, + memory: Optional[TokenBufferMemory] = None) -> Optional[str]: + """ + Retrieve dataset context + :param tenant_id: tenant id + :param app_record: app record + :param queue_manager: queue manager + :param model_config: model config + :param dataset_config: dataset config + :param show_retrieve_source: show retrieve source + :param message: message + :param inputs: inputs + :param query: query + :param user_id: user id + :param invoke_from: invoke from + :param memory: memory + :return: + """ + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + user_id, + invoke_from + ) + + # TODO + if (app_record.mode == AppMode.COMPLETION.value and dataset_config + and dataset_config.retrieve_config.query_variable): + query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval() + return dataset_retrieval.retrieve( + tenant_id=tenant_id, + model_config=model_config, + 
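recale_llm_max_tokens exists because the prompt is reorganized after max_tokens was chosen, and memory, external data, and dataset context can all grow it; if prompt plus completion budget would overflow the context window, the budget must shrink. The arithmetic, as a hedged guess at the rule the helper applies:

```python
def clamp_max_tokens(prompt_tokens: int, max_tokens: int, context_limit: int) -> int:
    # If prompt + completion budget exceeds the model's context window,
    # shrink the completion budget to whatever space is left (never below 0).
    if prompt_tokens + max_tokens > context_limit:
        return max(context_limit - prompt_tokens, 0)
    return max_tokens

assert clamp_max_tokens(3800, 512, 4096) == 296
assert clamp_max_tokens(100, 512, 4096) == 512
```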
config=dataset_config, + query=query, + invoke_from=invoke_from, + show_retrieve_source=show_retrieve_source, + hit_callback=hit_callback, + memory=memory + ) + \ No newline at end of file diff --git a/api/core/apps/app_config_validators/completion_app.py b/api/core/app/completion/config_validator.py similarity index 76% rename from api/core/apps/app_config_validators/completion_app.py rename to api/core/app/completion/config_validator.py index 00371f8d05fadd..7cc35efd64ac62 100644 --- a/api/core/apps/app_config_validators/completion_app.py +++ b/api/core/app/completion/config_validator.py @@ -1,12 +1,12 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.more_like_this import MoreLikeThisValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.more_like_this import MoreLikeThisValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -32,7 +32,7 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/agent/__init__.py b/api/core/app/features/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/__init__.py rename to api/core/app/features/__init__.py diff --git a/api/core/features/dataset_retrieval/agent/output_parser/__init__.py b/api/core/app/features/annotation_reply/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/__init__.py rename to api/core/app/features/annotation_reply/__init__.py diff --git a/api/core/features/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py similarity index 100% rename from api/core/features/annotation_reply.py rename to api/core/app/features/annotation_reply/annotation_reply.py diff --git a/api/core/app/features/hosting_moderation/__init__.py b/api/core/app/features/hosting_moderation/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py similarity index 100% rename from api/core/features/hosting_moderation.py 
rename to api/core/app/features/hosting_moderation/hosting_moderation.py diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py similarity index 98% rename from api/core/app_runner/generate_task_pipeline.py rename to api/core/app/generate_task_pipeline.py index 1cc56483ad3770..6d52fa7348ef95 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -6,8 +6,8 @@ from pydantic import BaseModel -from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import ( AnnotationReplyEvent, @@ -35,7 +35,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db @@ -59,7 +59,7 @@ class GenerateTaskPipeline: """ def __init__(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -633,7 +633,7 @@ def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMess return prompts - def _init_output_moderation(self) -> Optional[OutputModerationHandler]: + def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. 
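For reference, the conditional construction that follows: the absence of a sensitive_word_avoidance config means no OutputModeration instance at all, not a disabled one. Sketched with pydantic; the `type` field on ModerationRule is assumed, since only `config` is visible in this hunk:

```python
from typing import Any, Optional
from pydantic import BaseModel

class ModerationRule(BaseModel):
    type: str                    # assumed field, not shown in the hunk
    config: dict[str, Any]

class OutputModeration(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300
    tenant_id: str
    app_id: str
    rule: ModerationRule

def init_output_moderation(tenant_id: str, app_id: str,
                           sensitive_word_avoidance: Optional[dict]
                           ) -> Optional[OutputModeration]:
    # Mirrors _init_output_moderation: no config, no moderation object.
    if not sensitive_word_avoidance:
        return None
    return OutputModeration(
        tenant_id=tenant_id,
        app_id=app_id,
        rule=ModerationRule(type=sensitive_word_avoidance['type'],
                            config=sensitive_word_avoidance.get('config', {})),
    )
```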
:return: @@ -642,7 +642,7 @@ def _init_output_moderation(self) -> Optional[OutputModerationHandler]: sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance if sensitive_word_avoidance: - return OutputModerationHandler( + return OutputModeration( tenant_id=self._application_generate_entity.tenant_id, app_id=self._application_generate_entity.app_id, rule=ModerationRule( diff --git a/api/core/app/validators/__init__.py b/api/core/app/validators/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/apps/config_validators/dataset.py b/api/core/app/validators/dataset_retrieval.py similarity index 100% rename from api/core/apps/config_validators/dataset.py rename to api/core/app/validators/dataset_retrieval.py diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/app/validators/external_data_fetch.py similarity index 97% rename from api/core/apps/config_validators/external_data_tools.py rename to api/core/app/validators/external_data_fetch.py index 02ecc8d71598c5..5910aa17e76fd8 100644 --- a/api/core/apps/config_validators/external_data_tools.py +++ b/api/core/app/validators/external_data_fetch.py @@ -2,7 +2,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory -class ExternalDataToolsValidator: +class ExternalDataFetchValidator: @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/app/validators/file_upload.py similarity index 100% rename from api/core/apps/config_validators/file_upload.py rename to api/core/app/validators/file_upload.py diff --git a/api/core/apps/config_validators/model.py b/api/core/app/validators/model_validator.py similarity index 100% rename from api/core/apps/config_validators/model.py rename to api/core/app/validators/model_validator.py diff --git a/api/core/apps/config_validators/moderation.py b/api/core/app/validators/moderation.py similarity index 100% rename from api/core/apps/config_validators/moderation.py rename to api/core/app/validators/moderation.py diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/app/validators/more_like_this.py similarity index 100% rename from api/core/apps/config_validators/more_like_this.py rename to api/core/app/validators/more_like_this.py diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/app/validators/opening_statement.py similarity index 100% rename from api/core/apps/config_validators/opening_statement.py rename to api/core/app/validators/opening_statement.py diff --git a/api/core/apps/config_validators/prompt.py b/api/core/app/validators/prompt.py similarity index 100% rename from api/core/apps/config_validators/prompt.py rename to api/core/app/validators/prompt.py diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/app/validators/retriever_resource.py similarity index 100% rename from api/core/apps/config_validators/retriever_resource.py rename to api/core/app/validators/retriever_resource.py diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/app/validators/speech_to_text.py similarity index 100% rename from api/core/apps/config_validators/speech_to_text.py rename to api/core/app/validators/speech_to_text.py diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/app/validators/suggested_questions.py similarity index 100% rename from 
api/core/apps/config_validators/suggested_questions.py rename to api/core/app/validators/suggested_questions.py diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/app/validators/text_to_speech.py similarity index 100% rename from api/core/apps/config_validators/text_to_speech.py rename to api/core/app/validators/text_to_speech.py diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/app/validators/user_input_form.py similarity index 100% rename from api/core/apps/config_validators/user_input_form.py rename to api/core/app/validators/user_input_form.py diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/apps/app_config_validators/workflow_app.py b/api/core/app/workflow/config_validator.py similarity index 83% rename from api/core/apps/app_config_validators/workflow_app.py rename to api/core/app/workflow/config_validator.py index 545d3d79a330e5..b76eabaeb555e5 100644 --- a/api/core/apps/app_config_validators/workflow_app.py +++ b/api/core/app/workflow/config_validator.py @@ -1,6 +1,6 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class WorkflowAppConfigValidator: diff --git a/api/core/apps/app_config_validators/agent_chat_app.py b/api/core/apps/app_config_validators/agent_chat_app.py deleted file mode 100644 index d507fae685613b..00000000000000 --- a/api/core/apps/app_config_validators/agent_chat_app.py +++ /dev/null @@ -1,82 +0,0 @@ -from core.apps.config_validators.agent import AgentValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator -from models.model import AppMode - - -class AgentChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: - """ - Validate for agent chat app model config - - :param tenant_id: tenant id - :param config: app model config args - """ - app_mode = AppMode.AGENT_CHAT - - related_config_keys = [] - - # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, 
current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # agent_mode - config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py deleted file mode 100644 index b445aedbf868af..00000000000000 --- a/api/core/apps/config_validators/agent.py +++ /dev/null @@ -1,81 +0,0 @@ -import uuid - -from core.apps.config_validators.dataset import DatasetValidator -from core.entities.agent_entities import PlanningStrategy - -OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] - - -class AgentValidator: - @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: - """ - Validate and set defaults for agent feature - - :param tenant_id: tenant ID - :param config: app model config args - """ - if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if not 
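Every *ConfigValidator.config_validate in this patch follows the same pipeline: each step normalizes its slice of the config and returns the top-level keys it owns, and the union of those keys is used to drop everything else at the end. The skeleton, with two illustrative steps:

```python
def validate_model(tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    config.setdefault('model', {})            # stand-in normalization
    return config, ['model']

def validate_user_input_form(tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    config.setdefault('user_input_form', [])
    return config, ['user_input_form']

def config_validate(tenant_id: str, config: dict) -> dict:
    related_config_keys: list[str] = []
    for step in (validate_model, validate_user_input_form):
        config, keys = step(tenant_id, config)
        related_config_keys.extend(keys)
    related_config_keys = list(set(related_config_keys))
    # Filter out extra parameters, exactly as the validators above do.
    return {key: config.get(key) for key in related_config_keys}

print(config_validate('t-1', {'model': {'name': 'gpt-4'}, 'junk': 1}))
```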
config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] - - if not isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in OLD_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - return config, ["agent_mode"] diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index 1d25b8ab69d7e5..8a340a8b815111 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -7,7 +7,7 @@ from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop from core.entities.application_entities import ModelConfigEntity from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult @@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): raise_error: bool = True def __init__(self, model_config: ModelConfigEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, message_chain: MessageChain) -> None: """Initialize callback handler.""" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 879c9df69dca54..e49a09d4c4b7ba 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,5 +1,5 @@ -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from core.rag.models.document import Document from extensions.ext_database import db @@ -10,7 +10,7 @@ class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: ApplicationQueueManager, + def __init__(self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, diff --git 
a/api/core/features/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py similarity index 98% rename from api/core/features/external_data_fetch.py rename to api/core/external_data_tool/external_data_fetch.py index ef37f055289cb4..64c7d1e859532f 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class ExternalDataFetchFeature: +class ExternalDataFetch: def fetch(self, tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index dd46aa27dc612c..01a8ea3a5d0c6c 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/llm_generator/__init__.py b/api/core/llm_generator/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/generator/llm_generator.py b/api/core/llm_generator/llm_generator.py similarity index 93% rename from api/core/generator/llm_generator.py rename to api/core/llm_generator/llm_generator.py index 072b02dc94638a..6ce70df70323bf 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -7,10 +7,10 @@ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser -from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser -from core.prompt.prompt_template import PromptTemplateParser -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: diff --git a/api/core/llm_generator/output_parser/__init__.py b/api/core/llm_generator/output_parser/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py similarity index 94% rename from api/core/prompt/output_parser/rule_config_generator.py rename to api/core/llm_generator/output_parser/rule_config_generator.py index 619555ce2e99f8..b95653f69c6eea 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -2,7 +2,7 @@ from langchain.schema import BaseOutputParser, OutputParserException -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE +from 
core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py similarity index 87% rename from api/core/prompt/output_parser/suggested_questions_after_answer.py rename to api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e37142ec9146c0..ad30bcfa079b35 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -4,7 +4,7 @@ from langchain.schema import BaseOutputParser -from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompts.py b/api/core/llm_generator/prompts.py similarity index 100% rename from api/core/prompt/prompts.py rename to api/core/llm_generator/prompts.py diff --git a/api/core/features/moderation.py b/api/core/moderation/input_moderation.py similarity index 98% rename from api/core/features/moderation.py rename to api/core/moderation/input_moderation.py index a9d65f56e85c70..2129c58d8d2923 100644 --- a/api/core/features/moderation.py +++ b/api/core/moderation/input_moderation.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -class ModerationFeature: +class InputModeration: def check(self, app_id: str, tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, diff --git a/api/core/app_runner/moderation_handler.py b/api/core/moderation/output_moderation.py similarity index 97% rename from api/core/app_runner/moderation_handler.py rename to api/core/moderation/output_moderation.py index b2098344c843ac..749ee431e8f319 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from pydantic import BaseModel -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -18,7 +18,7 @@ class ModerationRule(BaseModel): config: dict[str, Any] -class OutputModerationHandler(BaseModel): +class OutputModeration(BaseModel): DEFAULT_BUFFER_SIZE: int = 300 tenant_id: str diff --git a/api/core/prompt/__init__.py b/api/core/prompt/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 7519971ce75292..617845392041fa 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,7 +15,7 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode diff --git a/api/core/prompt/prompt_templates/__init__.py b/api/core/prompt/prompt_templates/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py 
similarity index 100% rename from api/core/prompt/advanced_prompt_templates.py rename to api/core/prompt/prompt_templates/advanced_prompt_templates.py diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/prompt_templates/baichuan_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_chat.json rename to api/core/prompt/prompt_templates/baichuan_chat.json diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/prompt_templates/baichuan_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_completion.json rename to api/core/prompt/prompt_templates/baichuan_completion.json diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/prompt_templates/common_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/common_chat.json rename to api/core/prompt/prompt_templates/common_chat.json diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/prompt_templates/common_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/common_completion.json rename to api/core/prompt/prompt_templates/common_completion.json diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fcae0dc78643bb..f3a03b01c72215 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -15,7 +15,7 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from models.model import AppMode @@ -275,7 +275,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') # Open the JSON file and read its content diff --git a/api/core/prompt/utils/__init__.py b/api/core/prompt/utils/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/utils/prompt_template_parser.py similarity index 100% rename from api/core/prompt/prompt_template.py rename to api/core/prompt/utils/prompt_template_parser.py diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0d81c419d67ab1..139bfe15f328d6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from werkzeug.datastructures import FileStorage -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/core/rag/retrieval/__init__.py b/api/core/rag/retrieval/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git 
a/api/core/rag/retrieval/agent/__init__.py b/api/core/rag/retrieval/agent/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/dataset_retrieval/agent/agent_llm_callback.py b/api/core/rag/retrieval/agent/agent_llm_callback.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/agent_llm_callback.py rename to api/core/rag/retrieval/agent/agent_llm_callback.py diff --git a/api/core/features/dataset_retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/fake_llm.py rename to api/core/rag/retrieval/agent/fake_llm.py diff --git a/api/core/features/dataset_retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py similarity index 91% rename from api/core/features/dataset_retrieval/agent/llm_chain.py rename to api/core/rag/retrieval/agent/llm_chain.py index e5155e15a0849f..d07ee0a582c2c7 100644 --- a/api/core/features/dataset_retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -7,8 +7,8 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance diff --git a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py similarity index 98% rename from api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 59923202fde47a..8cc2e297438777 100644 --- a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -12,7 +12,7 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/rag/retrieval/agent/output_parser/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/agent/output_parser/structured_chat.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py rename to api/core/rag/retrieval/agent/output_parser/structured_chat.py diff --git a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py similarity index 99% rename from api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index e69302bfd682ea..4d7d33038bd8be 100644 --- a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py 
@@ -13,7 +13,7 @@ from langchain.tools import BaseTool from core.entities.application_entities import ModelConfigEntity -from core.features.dataset_retrieval.agent.llm_chain import LLMChain +from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. diff --git a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py similarity index 92% rename from api/core/features/dataset_retrieval/agent_based_dataset_executor.py rename to api/core/rag/retrieval/agent_based_dataset_executor.py index 588ccc91f5f088..f1ccf986e9a0e9 100644 --- a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -10,10 +10,10 @@ from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent +from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser +from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/features/dataset_retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py similarity index 98% rename from api/core/features/dataset_retrieval/dataset_retrieval.py rename to api/core/rag/retrieval/dataset_retrieval.py index 3e54d8644d3b46..07682389d698da 100644 --- a/api/core/features/dataset_retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -5,7 +5,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor +from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -15,7 +15,7 @@ from models.dataset import Dataset -class DatasetRetrievalFeature: +class DatasetRetrieval: def retrieve(self, tenant_id: str, model_config: ModelConfigEntity, config: DatasetEntity, diff --git a/api/core/tools/tool/dataset_retriever_tool.py 
b/api/core/tools/tool/dataset_retriever_tool.py index 30128c4dcaeea8..629ed2361341b8 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -4,7 +4,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.tool.tool import Tool @@ -30,7 +30,7 @@ def get_dataset_tools(tenant_id: str, if retrieve_config is None: return [] - feature = DatasetRetrievalFeature() + feature = DatasetRetrieval() # save original retrieve strategy, and set retrieve strategy to SINGLE # Agent only support SINGLE mode diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 74dc8d5112fa5c..f5f3ba2540d7aa 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,4 +1,4 @@ -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db diff --git a/api/models/model.py b/api/models/model.py index 8d286d34827f71..235f77abc32152 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -310,22 +310,28 @@ def to_dict(self) -> dict: def from_model_config_dict(self, model_config: dict): self.opening_statement = model_config['opening_statement'] - self.suggested_questions = json.dumps(model_config['suggested_questions']) - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) + self.suggested_questions = json.dumps(model_config['suggested_questions']) \ + if model_config.get('suggested_questions') else None + self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ + if model_config.get('suggested_questions_after_answer') else None self.speech_to_text = json.dumps(model_config['speech_to_text']) \ if model_config.get('speech_to_text') else None self.text_to_speech = json.dumps(model_config['text_to_speech']) \ if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) + self.more_like_this = json.dumps(model_config['more_like_this']) \ + if model_config.get('more_like_this') else None self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ if model_config.get('sensitive_word_avoidance') else None self.external_data_tools = json.dumps(model_config['external_data_tools']) \ if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) - self.user_input_form = json.dumps(model_config['user_input_form']) + self.model = json.dumps(model_config['model']) \ + if model_config.get('model') else None + self.user_input_form = json.dumps(model_config['user_input_form']) \ + if model_config.get('user_input_form') else None self.dataset_query_variable = model_config.get('dataset_query_variable') 
self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) + self.agent_mode = json.dumps(model_config['agent_mode']) \ + if model_config.get('agent_mode') else None self.retriever_resource = json.dumps(model_config['retriever_resource']) \ if model_config.get('retriever_resource') else None self.prompt_type = model_config.get('prompt_type', 'simple') diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 1e893e0eca4cad..213df262223d8a 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,7 @@ import copy -from core.prompt.advanced_prompt_templates import ( +from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c1e0ecebe82a1f..789d74ed2cc13f 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,8 +1,8 @@ -from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator -from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator -from core.apps.app_config_validators.chat_app import ChatAppConfigValidator -from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator -from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator +from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator +from core.app.agent_chat.config_validator import AgentChatAppConfigValidator +from core.app.chat.config_validator import ChatAppConfigValidator +from core.app.completion.config_validator import CompletionAppConfigValidator +from core.app.workflow.config_validator import WorkflowAppConfigValidator from models.model import AppMode diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 9acd62b997f044..8a9639e521aded 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -4,8 +4,8 @@ from sqlalchemy import and_ -from core.application_manager import ApplicationManager -from core.apps.config_validators.model import ModelValidator +from core.app.app_manager import AppManager +from core.app.validators.model_validator import ModelValidator from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db @@ -137,7 +137,7 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, user ) - application_manager = ApplicationManager() + application_manager = AppManager() return application_manager.generate( tenant_id=app_model.tenant_id, app_id=app_model.id, @@ -193,7 +193,7 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message.files, app_model_config ) - application_manager = ApplicationManager() + application_manager = AppManager() return application_manager.generate( tenant_id=app_model.tenant_id, app_id=app_model.id, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ac3df380b265a7..1a0213799e619a 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from 
core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account diff --git a/api/services/message_service.py b/api/services/message_service.py index ad2ff60f6b83c7..20918a8781bed3 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index fb6cf1fd5a4fb5..f384855e7a1d24 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,7 +1,7 @@ import json from typing import Optional -from core.application_manager import ApplicationManager +from core.app.app_manager import AppManager from core.entities.application_entities import ( DatasetEntity, DatasetRetrieveConfigEntity, @@ -111,7 +111,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, new_app_mode = self._get_new_app_mode(app_model) # convert app model config - application_manager = ApplicationManager() + application_manager = AppManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, app_model_config_dict=app_model_config.to_dict(), diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 95f1e30b449c72..69acb23681f17f 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -8,7 +8,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation From 9467fe9aa9f14a111816abc739fdecfd7c043d84 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 17:34:18 +0800 Subject: [PATCH 045/160] lint fix --- api/core/agent/base_agent_runner.py | 2 +- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- api/core/app/agent_chat/app_runner.py | 6 ++--- api/core/app/agent_chat/config_validator.py | 3 +-- api/core/app/app_manager.py | 4 ++-- .../app/app_orchestration_config_converter.py | 23 +++++++++++++++---- api/core/app/base_app_runner.py | 6 ++--- api/core/app/chat/app_runner.py | 4 ++-- api/core/app/completion/app_runner.py | 4 ++-- api/core/app/generate_task_pipeline.py | 2 +- api/core/llm_generator/llm_generator.py | 6 ++--- .../suggested_questions_after_answer.py | 1 + api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/prompt/simple_prompt_transform.py | 2 +- api/core/rag/retrieval/agent/llm_chain.py | 2 +- .../agent/multi_dataset_router_agent.py | 2 +- .../retrieval/agent_based_dataset_executor.py | 6 ++--- 
api/core/rag/retrieval/dataset_retrieval.py | 2 +- 19 files changed, 47 insertions(+), 34 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0658124d1422ff..1474c6a4757bf5 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,8 +5,8 @@ from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app.base_app_runner import AppRunner from core.app.app_queue_manager import AppQueueManager +from core.app.base_app_runner import AppRunner from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 152e4457955fae..5650113f470c2f 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -3,9 +3,9 @@ from collections.abc import Generator from typing import Literal, Union +from core.agent.base_agent_runner import BaseAgentRunner from core.app.app_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit -from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 0cf0d3762cea43..9b238bf2324b84 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -3,8 +3,8 @@ from collections.abc import Generator from typing import Any, Union -from core.app.app_queue_manager import PublishFrom from core.agent.base_agent_runner import BaseAgentRunner +from core.app.app_queue_manager import PublishFrom from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/app/agent_chat/app_runner.py b/api/core/app/agent_chat/app_runner.py index b046e935a52a8a..38789348ad749d 100644 --- a/api/core/app/agent_chat/app_runner.py +++ b/api/core/app/agent_chat/app_runner.py @@ -1,11 +1,11 @@ import logging from typing import cast -from core.app.base_app_runner import AppRunner -from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity from core.agent.cot_agent_runner import CotAgentRunner from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/agent_chat/config_validator.py index 6596b19f9945e7..82bc40bd9bfd7f 100644 --- a/api/core/app/agent_chat/config_validator.py +++ b/api/core/app/agent_chat/config_validator.py @@ -1,6 +1,5 @@ import uuid -from core.entities.agent_entities import PlanningStrategy from 
core.app.validators.dataset_retrieval import DatasetValidator from core.app.validators.external_data_fetch import ExternalDataFetchValidator from core.app.validators.file_upload import FileUploadValidator @@ -13,9 +12,9 @@ from core.app.validators.suggested_questions import SuggestedQuestionsValidator from core.app.validators.text_to_speech import TextToSpeechValidator from core.app.validators.user_input_form import UserInputFormValidator +from core.entities.agent_entities import PlanningStrategy from models.model import AppMode - OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py index 0819ed864baab0..86c8d2cfc735b2 100644 --- a/api/core/app/app_manager.py +++ b/api/core/app/app_manager.py @@ -8,11 +8,11 @@ from flask import Flask, current_app from pydantic import ValidationError -from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter from core.app.agent_chat.app_runner import AgentChatAppRunner +from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.chat.app_runner import ChatAppRunner from core.app.generate_task_pipeline import GenerateTaskPipeline -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.entities.application_entities import ( ApplicationGenerateEntity, InvokeFrom, diff --git a/api/core/app/app_orchestration_config_converter.py b/api/core/app/app_orchestration_config_converter.py index ddf49949a3bfb5..1d429ee6d91a46 100644 --- a/api/core/app/app_orchestration_config_converter.py +++ b/api/core/app/app_orchestration_config_converter.py @@ -1,11 +1,24 @@ from typing import cast -from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \ - TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \ - ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \ - AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity +from core.entities.application_entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + AgentEntity, + AgentPromptEntity, + AgentToolEntity, + AppOrchestrationConfigEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + SensitiveWordAvoidanceEntity, + TextToSpeechEntity, + VariableEntity, +) from core.entities.model_entities import ModelStatus -from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/app/base_app_runner.py b/api/core/app/base_app_runner.py index 788e3f91a3fcaf..2760d04180a00c 100644 --- a/api/core/app/base_app_runner.py +++ b/api/core/app/base_app_runner.py @@ -3,6 +3,8 @@ from typing import Optional, Union, cast from core.app.app_queue_manager import AppQueueManager, PublishFrom +from 
core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.entities.application_entities import ( ApplicationGenerateEntity, AppOrchestrationConfigEntity, @@ -11,10 +13,7 @@ ModelConfigEntity, PromptTemplateEntity, ) -from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch -from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature -from core.moderation.input_moderation import InputModeration from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -22,6 +21,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation diff --git a/api/core/app/chat/app_runner.py b/api/core/app/chat/app_runner.py index a1613e37a27c5d..a1eccab13a66da 100644 --- a/api/core/app/chat/app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,8 +1,8 @@ import logging from typing import Optional -from core.app.base_app_runner import AppRunner from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, @@ -10,10 +10,10 @@ InvokeFrom, ModelConfigEntity, ) -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, AppMode, Conversation, Message diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py index 34c6a5156f6eee..3ac182b34e4e31 100644 --- a/api/core/app/completion/app_runner.py +++ b/api/core/app/completion/app_runner.py @@ -1,8 +1,8 @@ import logging from typing import Optional -from core.app.base_app_runner import AppRunner from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, @@ -10,10 +10,10 @@ InvokeFrom, ModelConfigEntity, ) -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, AppMode, Conversation, Message diff --git a/api/core/app/generate_task_pipeline.py 
b/api/core/app/generate_task_pipeline.py index 6d52fa7348ef95..dc6ea2db7924c5 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -6,7 +6,6 @@ from pydantic import BaseModel -from core.moderation.output_moderation import ModerationRule, OutputModeration from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import ( @@ -35,6 +34,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder +from core.moderation.output_moderation import ModerationRule, OutputModeration from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 6ce70df70323bf..1a6b71fb0ad20f 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -3,14 +3,14 @@ from langchain.schema import OutputParserException +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser -from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index ad30bcfa079b35..1b955c6edd2442 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -5,6 +5,7 @@ from langchain.schema import BaseOutputParser from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.model_runtime.errors.invoke import InvokeError class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 617845392041fa..6d0a1d31f5ad8e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,9 +15,9 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_template_parser import PromptTemplateParser class 
AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index f3a03b01c72215..af7b695bb33718 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -15,8 +15,8 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py index d07ee0a582c2c7..087b7bfa2c34e3 100644 --- a/api/core/rag/retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -7,9 +7,9 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.fake_llm import FakeLLM -from core.model_manager import ModelInstance class LLMChain(LCLLMChain): diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 8cc2e297438777..41a0c54041f8e9 100644 --- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -12,9 +12,9 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool +from core.rag.retrieval.agent.fake_llm import FakeLLM class MultiDatasetRouterAgent(OpenAIFunctionsAgent): diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py index f1ccf986e9a0e9..7fabf71bedc3a1 100644 --- a/api/core/rag/retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -10,13 +10,13 @@ from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages +from core.helper import moderation +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.errors.invoke import InvokeError from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent -from core.helper import moderation -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.errors.invoke import InvokeError from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 07682389d698da..21e16c4162171d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ 
b/api/core/rag/retrieval/dataset_retrieval.py @@ -5,10 +5,10 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity -from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db From 8a8882ed8d09882de1c02a55c4b35bdf0eee9dcd Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 22:03:03 +0800 Subject: [PATCH 046/160] move workflow_id to app --- api/constants/model_template.py | 11 +- api/controllers/console/app/workflow.py | 8 +- api/core/app/chat/app_runner.py | 81 ++--------- api/core/app/completion/app_runner.py | 134 +++--------------- api/fields/workflow_fields.py | 5 +- .../versions/b289e2408ee2_add_workflow.py | 5 +- api/models/model.py | 22 ++- api/models/workflow.py | 10 ++ api/services/app_service.py | 96 ++++++++----- api/services/workflow/workflow_converter.py | 54 ++++--- api/services/workflow_service.py | 39 ++--- 11 files changed, 166 insertions(+), 299 deletions(-) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 61aab64d8a298e..c8aaba23cb83ed 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -7,8 +7,7 @@ 'mode': AppMode.WORKFLOW.value, 'enable_site': True, 'enable_api': True - }, - 'model_config': {} + } }, # chat default mode @@ -34,14 +33,6 @@ 'mode': AppMode.ADVANCED_CHAT.value, 'enable_site': True, 'enable_api': True - }, - 'model_config': { - 'model': { - "provider": "openai", - "name": "gpt-4", - "mode": "chat", - "completion_params": {} - } } }, diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 4fcf8daf6ec64e..54585d8519a9a3 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -41,10 +41,16 @@ def post(self, app_model: App): """ parser = reqparse.RequestParser() parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') + parser.add_argument('features', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() workflow_service = WorkflowService() - workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) + workflow_service.sync_draft_workflow( + app_model=app_model, + graph=args.get('graph'), + features=args.get('features'), + account=current_user + ) return { "result": "success" diff --git a/api/core/app/chat/app_runner.py b/api/core/app/chat/app_runner.py index a1eccab13a66da..4c8018572e69b2 100644 --- a/api/core/app/chat/app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,21 +1,17 @@ import logging -from typing import Optional from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.base_app_runner import AppRunner from 
core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - DatasetEntity, - InvokeFrom, - ModelConfigEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from models.model import App, AppMode, Conversation, Message +from models.model import App, Conversation, Message logger = logging.getLogger(__name__) @@ -145,18 +141,23 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # get context from datasets context = None if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: - context = self.retrieve_dataset_context( + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from + ) + + dataset_retrieval = DatasetRetrieval() + context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - app_record=app_record, - queue_manager=queue_manager, model_config=app_orchestration_config.model_config, - show_retrieve_source=app_orchestration_config.show_retrieve_source, - dataset_config=app_orchestration_config.dataset, - message=message, - inputs=inputs, + config=app_orchestration_config.dataset, query=query, - user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_orchestration_config.show_retrieve_source, + hit_callback=hit_callback, memory=memory ) @@ -212,57 +213,3 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, queue_manager=queue_manager, stream=application_generate_entity.stream ) - - def retrieve_dataset_context(self, tenant_id: str, - app_record: App, - queue_manager: AppQueueManager, - model_config: ModelConfigEntity, - dataset_config: DatasetEntity, - show_retrieve_source: bool, - message: Message, - inputs: dict, - query: str, - user_id: str, - invoke_from: InvokeFrom, - memory: Optional[TokenBufferMemory] = None) -> Optional[str]: - """ - Retrieve dataset context - :param tenant_id: tenant id - :param app_record: app record - :param queue_manager: queue manager - :param model_config: model config - :param dataset_config: dataset config - :param show_retrieve_source: show retrieve source - :param message: message - :param inputs: inputs - :param query: query - :param user_id: user id - :param invoke_from: invoke from - :param memory: memory - :return: - """ - hit_callback = DatasetIndexToolCallbackHandler( - queue_manager, - app_record.id, - message.id, - user_id, - invoke_from - ) - - # TODO - if (app_record.mode == AppMode.COMPLETION.value and dataset_config - and dataset_config.retrieve_config.query_variable): - query = inputs.get(dataset_config.retrieve_config.query_variable, "") - - dataset_retrieval = DatasetRetrieval() - return dataset_retrieval.retrieve( - tenant_id=tenant_id, - model_config=model_config, - config=dataset_config, - query=query, - invoke_from=invoke_from, - show_retrieve_source=show_retrieve_source, - hit_callback=hit_callback, - memory=memory - ) - \ No newline at end of file diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py index 3ac182b34e4e31..ab2f40ad9a877a 100644 --- a/api/core/app/completion/app_runner.py +++ 
b/api/core/app/completion/app_runner.py @@ -1,21 +1,16 @@ import logging -from typing import Optional -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager from core.app.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, - DatasetEntity, - InvokeFrom, - ModelConfigEntity, ) -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db -from models.model import App, AppMode, Conversation, Message +from models.model import App, Message logger = logging.getLogger(__name__) @@ -27,13 +22,11 @@ class CompletionAppRunner(AppRunner): def run(self, application_generate_entity: ApplicationGenerateEntity, queue_manager: AppQueueManager, - conversation: Conversation, message: Message) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager - :param conversation: conversation :param message: message :return: """ @@ -61,30 +54,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, query=query ) - memory = None - if application_generate_entity.conversation_id: - # get memory of conversation (read-only) - model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model - ) - - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) - # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) - # memory(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, model_config=app_orchestration_config.model_config, prompt_template_entity=app_orchestration_config.prompt_template, inputs=inputs, files=files, - query=query, - memory=memory + query=query ) # moderation @@ -107,30 +85,6 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) return - if query: - # annotation reply - annotation_reply = self.query_app_annotations_to_reply( - app_record=app_record, - message=message, - query=query, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from - ) - - if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER - ) - self.direct_output( - queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, - prompt_messages=prompt_messages, - text=annotation_reply.content, - stream=application_generate_entity.stream - ) - return - # fill in variable inputs from external data tools if exists external_data_tools = app_orchestration_config.external_data_variables if external_data_tools: @@ -145,19 +99,27 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # get context from datasets context = None if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: - context = self.retrieve_dataset_context( + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + 
application_generate_entity.invoke_from +            ) + +            dataset_config = app_orchestration_config.dataset +            if dataset_config and dataset_config.retrieve_config.query_variable: +                query = inputs.get(dataset_config.retrieve_config.query_variable, "") + +            dataset_retrieval = DatasetRetrieval() +            context = dataset_retrieval.retrieve(                 tenant_id=app_record.tenant_id, -                app_record=app_record, -                queue_manager=queue_manager,                 model_config=app_orchestration_config.model_config, -                show_retrieve_source=app_orchestration_config.show_retrieve_source, -                dataset_config=app_orchestration_config.dataset, -                message=message, -                inputs=inputs, +                config=dataset_config,                 query=query, -                user_id=application_generate_entity.user_id,                 invoke_from=application_generate_entity.invoke_from, -                memory=memory +                show_retrieve_source=app_orchestration_config.show_retrieve_source, +                hit_callback=hit_callback             )          # reorganize all inputs and template to prompt messages @@ -170,8 +132,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,             inputs=inputs,             files=files,             query=query, -            context=context, -            memory=memory +            context=context         )          # check hosting moderation @@ -210,57 +171,4 @@ def run(self, application_generate_entity: ApplicationGenerateEntity,             queue_manager=queue_manager,             stream=application_generate_entity.stream         ) - -    def retrieve_dataset_context(self, tenant_id: str, -                                 app_record: App, -                                 queue_manager: AppQueueManager, -                                 model_config: ModelConfigEntity, -                                 dataset_config: DatasetEntity, -                                 show_retrieve_source: bool, -                                 message: Message, -                                 inputs: dict, -                                 query: str, -                                 user_id: str, -                                 invoke_from: InvokeFrom, -                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]: -        """ -        Retrieve dataset context -        :param tenant_id: tenant id -        :param app_record: app record -        :param queue_manager: queue manager -        :param model_config: model config -        :param dataset_config: dataset config -        :param show_retrieve_source: show retrieve source -        :param message: message -        :param inputs: inputs -        :param query: query -        :param user_id: user id -        :param invoke_from: invoke from -        :param memory: memory -        :return: -        """ -        hit_callback = DatasetIndexToolCallbackHandler( -            queue_manager, -            app_record.id, -            message.id, -            user_id, -            invoke_from -        ) - -        # TODO -        if (app_record.mode == AppMode.COMPLETION.value and dataset_config -                and dataset_config.retrieve_config.query_variable): -            query = inputs.get(dataset_config.retrieve_config.query_variable, "") - -        dataset_retrieval = DatasetRetrieval() -        return dataset_retrieval.retrieve( -            tenant_id=tenant_id, -            model_config=model_config, -            config=dataset_config, -            query=query, -            invoke_from=invoke_from, -            show_retrieve_source=show_retrieve_source, -            hit_callback=hit_callback, -            memory=memory -        ) \ No newline at end of file diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index decdc0567f1934..bcb2c318c6a8f9 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,5 +1,3 @@ -import json - from flask_restful import fields  from fields.member_fields import simple_account_fields @@ -7,7 +5,8 @@ workflow_fields = {     'id': fields.String, -    'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), +    'graph': fields.Raw(attribute='graph_dict'), +    'features': fields.Raw(attribute='features_dict'),     'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),     'created_at': TimestampField,     'updated_by': fields.Nested(simple_account_fields, 
attribute='updated_by_account', allow_null=True), diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 5f7ddc7d688f28..5ae1e65611ab3b 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -97,6 +97,7 @@ def upgrade(): sa.Column('type', sa.String(length=255), nullable=False), sa.Column('version', sa.String(length=255), nullable=False), sa.Column('graph', sa.Text(), nullable=True), + sa.Column('features', sa.Text(), nullable=True), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('updated_by', postgresql.UUID(), nullable=True), @@ -106,7 +107,7 @@ def upgrade(): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) with op.batch_alter_table('messages', schema=None) as batch_op: @@ -120,7 +121,7 @@ def downgrade(): with op.batch_alter_table('messages', schema=None) as batch_op: batch_op.drop_column('workflow_run_id') - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('workflow_id') with op.batch_alter_table('workflows', schema=None) as batch_op: diff --git a/api/models/model.py b/api/models/model.py index 235f77abc32152..c6409c61edd9c2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -63,6 +63,7 @@ class App(db.Model): icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(UUID, nullable=True) + workflow_id = db.Column(UUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) @@ -85,6 +86,14 @@ def app_model_config(self) -> Optional['AppModelConfig']: AppModelConfig.id == self.app_model_config_id).first() return app_model_config + @property + def workflow(self): + if self.workflow_id: + from api.models.workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None + @property def api_base_url(self): return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] @@ -176,7 +185,6 @@ class AppModelConfig(db.Model): dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) - workflow_id = db.Column(UUID) @property def app(self): @@ -276,14 +284,6 @@ def file_upload_dict(self) -> dict: "image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} - @property - def workflow(self): - if self.workflow_id: - from api.models.workflow import Workflow - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() - - return None - def to_dict(self) -> dict: return { "opening_statement": self.opening_statement, @@ -343,7 +343,6 @@ def from_model_config_dict(self, model_config: dict): if model_config.get('dataset_configs') else None self.file_upload = 
json.dumps(model_config.get('file_upload')) \ if model_config.get('file_upload') else None - self.workflow_id = model_config.get('workflow_id') return self def copy(self): @@ -368,8 +367,7 @@ def copy(self): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload, - workflow_id=self.workflow_id + file_upload=self.file_upload ) return new_app_model_config diff --git a/api/models/workflow.py b/api/models/workflow.py index 316d3e623e2af2..c38c1dd61079e4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,3 +1,4 @@ +import json from enum import Enum from typing import Union @@ -106,6 +107,7 @@ class Workflow(db.Model): type = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) + features = db.Column(db.Text) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_by = db.Column(UUID) @@ -119,6 +121,14 @@ def created_by_account(self): def updated_by_account(self): return Account.query.get(self.updated_by) + @property + def graph_dict(self): + return self.graph if not self.graph else json.loads(self.graph) + + @property + def features_dict(self): + return self.features if not self.features else json.loads(self.features) + class WorkflowRunTriggeredFrom(Enum): """ diff --git a/api/services/app_service.py b/api/services/app_service.py index 374727d2d42495..7dd5d770eaa1bd 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -64,8 +64,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: + default_model_config = app_template.get('model_config') + if default_model_config and 'model' in default_model_config: # get model provider model_manager = ModelManager() @@ -110,12 +110,15 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: db.session.add(app) db.session.flush() - app_model_config = AppModelConfig(**default_model_config) - app_model_config.app_id = app.id - db.session.add(app_model_config) - db.session.flush() + if default_model_config: + app_model_config = AppModelConfig(**default_model_config) + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() + + app.app_model_config_id = app_model_config.id - app.app_model_config_id = app_model_config.id + db.session.commit() app_was_created.send(app, account=account) @@ -135,16 +138,22 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App: app_data = import_data.get('app') model_config_data = import_data.get('model_config') - workflow_graph = import_data.get('workflow_graph') + workflow = import_data.get('workflow') - if not app_data or not model_config_data: - raise ValueError("Missing app or model_config in data argument") + if not app_data: + raise ValueError("Missing app in data argument") app_mode = AppMode.value_of(app_data.get('mode')) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - if not workflow_graph: - raise ValueError("Missing workflow_graph in data argument " - "when mode is advanced-chat or workflow") + if not workflow: + raise ValueError("Missing workflow in data argument " + "when app mode is advanced-chat or workflow") + elif app_mode in 
[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
+            if not model_config_data:
+                raise ValueError("Missing model_config in data argument "
+                                 "when app mode is chat, agent-chat or completion")
+        else:
+            raise ValueError("Invalid app mode")
 
         app = App(
             tenant_id=tenant_id,
@@ -161,26 +170,32 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App:
         db.session.add(app)
         db.session.commit()
 
-        if workflow_graph:
+        app_was_created.send(app, account=account)
+
+        if workflow:
             # init draft workflow
             workflow_service = WorkflowService()
-            workflow_service.sync_draft_workflow(app, workflow_graph, account)
-
-        app_model_config = AppModelConfig()
-        app_model_config = app_model_config.from_model_config_dict(model_config_data)
-        app_model_config.app_id = app.id
+            workflow_service.sync_draft_workflow(
+                app_model=app,
+                graph=workflow.get('graph'),
+                features=workflow.get('features'),
+                account=account
+            )
 
-        db.session.add(app_model_config)
-        db.session.commit()
+        if model_config_data:
+            app_model_config = AppModelConfig()
+            app_model_config = app_model_config.from_model_config_dict(model_config_data)
+            app_model_config.app_id = app.id
 
-        app.app_model_config_id = app_model_config.id
+            db.session.add(app_model_config)
+            db.session.commit()
 
-        app_was_created.send(app, account=account)
+            app.app_model_config_id = app_model_config.id
 
-        app_model_config_was_updated.send(
-            app,
-            app_model_config=app_model_config
-        )
+            app_model_config_was_updated.send(
+                app,
+                app_model_config=app_model_config
+            )
 
         return app
 
@@ -190,7 +205,7 @@ def export_app(self, app: App) -> str:
         :param app: App instance
         :return:
         """
-        app_model_config = app.app_model_config
+        app_mode = AppMode.value_of(app.mode)
 
         export_data = {
             "app": {
@@ -198,16 +213,27 @@ def export_app(self, app: App) -> str:
                 "mode": app.mode,
                 "icon": app.icon,
                 "icon_background": app.icon_background
-            },
-            "model_config": app_model_config.to_dict(),
+            }
         }
 
-        if app_model_config.workflow_id:
-            export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph)
+        if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
+            if app.workflow_id:
+                workflow = app.workflow
+                export_data['workflow'] = {
+                    "graph": workflow.graph_dict,
+                    "features": workflow.features_dict
+                }
+            else:
+                workflow_service = WorkflowService()
+                workflow = workflow_service.get_draft_workflow(app)
+                export_data['workflow'] = {
+                    "graph": workflow.graph_dict,
+                    "features": workflow.features_dict
+                }
         else:
-            workflow_service = WorkflowService()
-            workflow = workflow_service.get_draft_workflow(app)
-            export_data['workflow_graph'] = json.loads(workflow.graph)
+            app_model_config = app.app_model_config
+
+            export_data['model_config'] = app_model_config.to_dict()
 
         return yaml.dump(export_data)
 
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index f384855e7a1d24..6c0182dd9e5bfd 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -44,13 +44,10 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App:
         :param account: Account
         :return: new App instance
         """
-        # get original app config
-        app_model_config = app_model.app_model_config
-
         # convert app model config
         workflow = self.convert_app_model_config_to_workflow(
             app_model=app_model,
-            app_model_config=app_model_config,
+            app_model_config=app_model.app_model_config,
             account_id=account.id
         )
 
@@ -58,8 +55,9 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App:
         new_app = App()
         new_app.tenant_id = app_model.tenant_id
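For orientation: after the app_service.py changes above, import_app and export_app agree on a mode-dependent DSL, with workflow-based apps carrying a "workflow" section and classic apps carrying a "model_config" section. A minimal sketch of the two payload shapes, using the key names from export_app; all values are illustrative placeholders, not taken from a real app:

# workflow-based apps (advanced-chat, workflow)
workflow_app_dsl = {
    "app": {"name": "example", "mode": "workflow", "icon": "emoji", "icon_background": "#FFEAD5"},
    # graph_dict / features_dict payloads, left empty here for brevity
    "workflow": {"graph": {"nodes": []}, "features": {}},
}

# classic apps (chat, agent-chat, completion)
model_config_app_dsl = {
    "app": {"name": "example", "mode": "chat", "icon": "emoji", "icon_background": "#FFEAD5"},
    # AppModelConfig.to_dict() payload, trimmed here for brevity
    "model_config": {"opening_statement": "", "pre_prompt": ""},
}

Either payload round-trips through the yaml.dump in export_app and the import_data parsing in import_app shown above.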
new_app.name = app_model.name + '(workflow)' - new_app.mode = AppMode.CHAT.value \ + new_app.mode = AppMode.ADVANCED_CHAT.value \ if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.workflow_id = workflow.id new_app.icon = app_model.icon new_app.icon_background = app_model.icon_background new_app.enable_site = app_model.enable_site @@ -69,28 +67,6 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: new_app.is_demo = False new_app.is_public = app_model.is_public db.session.add(new_app) - db.session.flush() - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.app_id = new_app.id - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - db.session.flush() - - new_app.app_model_config_id = new_app_model_config.id db.session.commit() app_was_created.send(new_app, account=account) @@ -110,11 +86,13 @@ def convert_app_model_config_to_workflow(self, app_model: App, # get new app mode new_app_mode = self._get_new_app_mode(app_model) + app_model_config_dict = app_model_config.to_dict() + # convert app model config application_manager = AppManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config.to_dict(), + app_model_config_dict=app_model_config_dict, skip_check=True ) @@ -177,6 +155,25 @@ def convert_app_model_config_to_workflow(self, app_model: App, graph = self._append_node(graph, end_node) + # features + if new_app_mode == AppMode.ADVANCED_CHAT: + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + else: + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + # create workflow record workflow = Workflow( tenant_id=app_model.tenant_id, @@ -184,6 +181,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), + features=json.dumps(features), created_by=account_id, created_at=app_model_config.created_at ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 5a9234c70a4c26..006bc44e41d639 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -33,29 +33,31 @@ def 
get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow """ - app_model_config = app_model.app_model_config - - if not app_model_config.workflow_id: + if not app_model.workflow_id: return None # fetch published workflow by workflow_id workflow = db.session.query(Workflow).filter( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, - Workflow.id == app_model_config.workflow_id + Workflow.id == app_model.workflow_id ).first() # return published workflow return workflow - - def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: + def sync_draft_workflow(self, app_model: App, + graph: dict, + features: dict, + account: Account) -> Workflow: """ Sync draft workflow """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) + # TODO validate features + # create draft workflow if not found if not workflow: workflow = Workflow( @@ -64,12 +66,14 @@ def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> type=WorkflowType.from_app_mode(app_model.mode).value, version='draft', graph=json.dumps(graph), + features=json.dumps(features), created_by=account.id ) db.session.add(workflow) # update draft workflow if found else: workflow.graph = json.dumps(graph) + workflow.features = json.dumps(features) workflow.updated_by = account.id workflow.updated_at = datetime.utcnow() @@ -112,28 +116,7 @@ def publish_workflow(self, app_model: App, db.session.add(workflow) db.session.commit() - app_model_config = app_model.app_model_config - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.app_id = app_model.id - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - db.session.flush() - - app_model.app_model_config_id = new_app_model_config.id + app_model.workflow_id = workflow.id db.session.commit() # TODO update app related datasets From 7bff65304fd4e672e95ccacf700a85c6d9070497 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 22:20:27 +0800 Subject: [PATCH 047/160] add features structure validate --- api/controllers/console/app/model_config.py | 36 +------------------ .../app/advanced_chat/config_validator.py | 9 +++-- api/core/app/validators/moderation.py | 18 +++++----- api/core/app/workflow/config_validator.py | 9 +++-- api/services/app_model_config_service.py | 9 ----- api/services/workflow_service.py | 26 ++++++++++++-- 6 files changed, 49 insertions(+), 58 deletions(-) diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index d822f859bc35e8..1301d12da4e560 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -2,7 +2,7 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restful import Resource from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -137,38 
+137,4 @@ def post(self, app_model): return {'result': 'success'} -class FeaturesResource(Resource): - - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def put(self, app_model): - """Get app features""" - parser = reqparse.RequestParser() - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - args = parser.parse_args() - - model_configuration = AppModelConfigService.validate_features( - tenant_id=current_user.current_tenant_id, - config=args.get('features'), - app_mode=AppMode.value_of(app_model.mode) - ) - - # update config - app_model_config = app_model.app_model_config - app_model_config.from_model_config_dict(model_configuration) - db.session.commit() - - app_model_config_was_updated.send( - app_model, - app_model_config=app_model_config - ) - - return { - 'result': 'success' - } - - api.add_resource(ModelConfigResource, '/apps//model-config') -api.add_resource(FeaturesResource, '/apps//features') diff --git a/api/core/app/advanced_chat/config_validator.py b/api/core/app/advanced_chat/config_validator.py index 39c00c028ef2ee..a20198ef4a344d 100644 --- a/api/core/app/advanced_chat/config_validator.py +++ b/api/core/app/advanced_chat/config_validator.py @@ -9,12 +9,13 @@ class AdvancedChatAppConfigValidator: @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: """ Validate for advanced chat app model config :param tenant_id: tenant id :param config: app model config args + :param only_structure_validate: if True, only structure validation will be performed """ related_config_keys = [] @@ -43,7 +44,11 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/validators/moderation.py b/api/core/app/validators/moderation.py index 4813385588905b..7a5dff55c9e6e0 100644 --- a/api/core/app/validators/moderation.py +++ b/api/core/app/validators/moderation.py @@ -7,7 +7,8 @@ class ModerationValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ + -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = { "enabled": False @@ -23,13 +24,14 @@ def validate_and_set_defaults(cls, tenant_id, config: dict) -> tuple[dict, list[ if not config["sensitive_word_avoidance"].get("type"): raise ValueError("sensitive_word_avoidance.type is required") - typ = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] + if not only_structure_validate: + typ = config["sensitive_word_avoidance"]["type"] + config = config["sensitive_word_avoidance"]["config"] - ModerationFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ModerationFactory.validate_config( + name=typ, + 
tenant_id=tenant_id, + config=config + ) return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/workflow/config_validator.py b/api/core/app/workflow/config_validator.py index b76eabaeb555e5..e8381146a75501 100644 --- a/api/core/app/workflow/config_validator.py +++ b/api/core/app/workflow/config_validator.py @@ -5,12 +5,13 @@ class WorkflowAppConfigValidator: @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: """ Validate for workflow app model config :param tenant_id: tenant id :param config: app model config args + :param only_structure_validate: only validate the structure of the config """ related_config_keys = [] @@ -23,7 +24,11 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 789d74ed2cc13f..a35b0dd36ed98f 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -18,12 +18,3 @@ def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) return CompletionAppConfigValidator.config_validate(tenant_id, config) else: raise ValueError(f"Invalid app mode: {app_mode}") - - @classmethod - def validate_features(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: - if app_mode == AppMode.ADVANCED_CHAT: - return AdvancedChatAppConfigValidator.config_validate(tenant_id, config) - elif app_mode == AppMode.WORKFLOW: - return WorkflowAppConfigValidator.config_validate(tenant_id, config) - else: - raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 006bc44e41d639..102c8617331b11 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import Optional +from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator +from core.app.workflow.config_validator import WorkflowAppConfigValidator from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -56,7 +58,11 @@ def sync_draft_workflow(self, app_model: App, # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) - # TODO validate features + # validate features structure + self.validate_features_structure( + app_model=app_model, + features=features + ) # create draft workflow if not found if not workflow: @@ -100,7 +106,7 @@ def publish_workflow(self, app_model: App, if not draft_workflow: raise ValueError('No valid workflow found.') - # TODO check if the workflow is valid, basic check + # TODO check if the workflow structure is valid # create new workflow workflow = Workflow( @@ -153,3 +159,19 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: ) return new_app + + def validate_features_structure(self, app_model: App, features: dict) -> dict: + 
if app_model.mode == AppMode.ADVANCED_CHAT.value: + return AdvancedChatAppConfigValidator.config_validate( + tenant_id=app_model.tenant_id, + config=features, + only_structure_validate=True + ) + elif app_model.mode == AppMode.WORKFLOW.value: + return WorkflowAppConfigValidator.config_validate( + tenant_id=app_model.tenant_id, + config=features, + only_structure_validate=True + ) + else: + raise ValueError(f"Invalid app mode: {app_model.mode}") From 9651a208a97b4f8da32611106bd47b93eafd30e3 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 22:20:31 +0800 Subject: [PATCH 048/160] lint fix --- api/services/app_model_config_service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index a35b0dd36ed98f..f2caeb14ff5217 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,8 +1,6 @@ -from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator from core.app.agent_chat.config_validator import AgentChatAppConfigValidator from core.app.chat.config_validator import ChatAppConfigValidator from core.app.completion.config_validator import CompletionAppConfigValidator -from core.app.workflow.config_validator import WorkflowAppConfigValidator from models.model import AppMode From 43b0440358886d2f94ca5cc714406d45ddc55972 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 22:58:30 +0800 Subject: [PATCH 049/160] support workflow features --- api/controllers/console/app/audio.py | 6 +- api/controllers/console/explore/audio.py | 14 +---- api/controllers/console/explore/parameter.py | 60 ++++++++++++++------ api/controllers/service_api/app/app.py | 51 ++++++++++++----- api/controllers/service_api/app/audio.py | 16 ++---- api/controllers/web/app.py | 49 +++++++++++----- api/controllers/web/audio.py | 16 +----- api/controllers/web/site.py | 4 -- api/core/file/message_file_parser.py | 6 +- api/core/memory/token_buffer_memory.py | 7 ++- api/models/model.py | 7 ++- api/models/workflow.py | 16 ++++++ api/services/app_service.py | 7 ++- api/services/audio_service.py | 49 ++++++++++++++-- 14 files changed, 211 insertions(+), 97 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 458fa5098f80a8..c7f3a598cafb4c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -43,7 +43,7 @@ def post(self, app_model): try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=None, ) @@ -83,9 +83,9 @@ class ChatMessageTextApi(Resource): def post(self, app_model): try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index dc546ce0dd28d7..34ce1ec1eec109 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -32,16 +32,12 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() file = request.files['file'] 
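A note on the pattern running through this commit: each controller diff drops its local speech_to_text / text_to_speech gate, and AudioService becomes the single place that decides whether the feature is enabled, branching on app mode. A minimal sketch of that shared gate, assuming only the App, AppMode, and Workflow.features_dict definitions from the model changes earlier in the series; the helper names here are illustrative, not part of the patch:

from models.model import App, AppMode


def resolve_features(app_model: App) -> dict:
    # workflow-based apps keep feature flags on the workflow record;
    # classic apps keep them on app_model_config
    if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
        workflow = app_model.workflow
        if workflow is None:
            raise ValueError("Workflow not initialized for this app")
        return workflow.features_dict
    return app_model.app_model_config.to_dict()


def speech_to_text_enabled(app_model: App) -> bool:
    # mirrors the check transcript_asr performs before accepting audio
    return bool(resolve_features(app_model).get('speech_to_text', {}).get('enabled'))

Centralizing the check keeps the controllers thin and makes the workflow / classic split a service-level concern.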
try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=None ) @@ -76,16 +72,12 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.text_to_speech_dict['enabled']: - raise AppUnavailableError() try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) return {'data': response.data.decode('latin1')} diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c4afb0b9236651..0239742a4a70b7 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with from controllers.console import api +from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource from extensions.ext_database import db -from models.model import AppModelConfig, InstalledApp +from models.model import AppModelConfig, InstalledApp, AppMode from models.tools import ApiToolProvider @@ -45,30 +46,55 @@ class AppParameterApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app - app_model_config = app_model.app_model_config + + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': 
features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" app_model_config: AppModelConfig = installed_app.app.app_model_config + if not app_model_config: + return { + 'tool_icons': {} + } + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} @@ -77,7 +103,7 @@ def get(self, installed_app: InstalledApp): # get all tools tools = agent_config.get('tools', []) url_prefix = (current_app.config.get("CONSOLE_API_URL") - + "/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: @@ -94,12 +120,14 @@ def get(self, installed_app: InstalledApp): ) meta['tool_icons'][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { + meta['tool_icons'][tool_name] = { "background": "#252525", "content": "\ud83d\ude01" } return meta -api.add_resource(AppParameterApi, '/installed-apps//parameters', endpoint='installed_app_parameters') + +api.add_resource(AppParameterApi, '/installed-apps//parameters', + endpoint='installed_app_parameters') api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index a3151fc4a21ea5..76708716c226e4 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with, Resource from controllers.service_api import api +from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, AppMode from models.tools import ApiToolProvider @@ -46,31 +47,55 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - app_model_config = app_model.app_model_config + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': 
app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMetaApi(Resource): @validate_app_token def get(self, app_model: App): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + return { + 'tool_icons': {} + } + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index f6cad501f09264..57edab40901380 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,18 +33,13 @@ class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, - end_user=end_user.get_id() + end_user=end_user ) return response @@ -79,15 +74,16 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('text', type=str, required=True, nullable=False, location='json') + parser.add_argument('voice', type=str, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') args = parser.parse_args() try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=args['text'], - end_user=end_user.get_id(), - voice=app_model.app_model_config.text_to_speech_dict.get('voice'), + end_user=end_user, + voice=args.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 25492b11432a6e..07ce098298a364 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with from controllers.web import api +from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, AppMode from models.tools import ApiToolProvider @@ -44,30 +45,52 @@ 
class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - app_model_config = app_model.app_model_config + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMeta(WebApiResource): def get(self, app_model: App, end_user): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + raise AppUnavailableError() + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 4e677ae288dd4a..8b8ab8f090b531 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,16 +31,11 @@ class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=end_user ) @@ -74,17 +69,12 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not 
app_model_config.text_to_speech_dict['enabled']: - raise AppUnavailableError() - try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], end_user=end_user.external_user_id, - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index d8e2d597071889..bf3536d2766ed8 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -83,7 +83,3 @@ def __init__(self, tenant, app, site, end_user, can_replace_logo): 'remove_webapp_brand': remove_webapp_brand, 'replace_webapp_logo': replace_webapp_logo, } - - if app.enable_site and site.prompt_public: - app_model_config = app.app_model_config - self.model_config = app_model_config diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 1b7b8b87da7214..c13207357820ee 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -96,16 +96,16 @@ def validate_and_transform_files_arg(self, files: list[dict], app_model_config: # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: + def transform_message_files(self, files: list[MessageFile], file_upload_config: Optional[dict]) -> list[FileObj]: """ transform message files :param files: - :param app_model_config: + :param file_upload_config: :return: """ # transform files to file objs - type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict) + type_file_objs = self._to_file_objs(files, file_upload_config) # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 4d44ac38183fb0..f9200dcc71d1f8 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -10,7 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db -from models.model import Conversation, Message +from models.model import Conversation, Message, AppMode class TokenBufferMemory: @@ -44,7 +44,10 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, files = message.message_files if files: file_objs = message_file_parser.transform_message_files( - files, message.app_model_config + files, + message.app_model_config.file_upload_dict + if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + else message.workflow_run.workflow.features_dict.get('file_upload', {}) ) if not file_objs: diff --git a/api/models/model.py b/api/models/model.py index c6409c61edd9c2..e514ea729bd2aa 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -82,9 +82,10 @@ def site(self): @property def app_model_config(self) -> Optional['AppModelConfig']: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() - return app_model_config + if self.app_model_config_id: + return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + + return None @property def workflow(self): diff --git a/api/models/workflow.py b/api/models/workflow.py index 
c38c1dd61079e4..ff4e944e29e2f1 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -129,6 +129,23 @@ def graph_dict(self):
     def features_dict(self):
         return self.features if not self.features else json.loads(self.features)
 
+    @property
+    def user_input_form(self):
+        # get start node from graph
+        if not self.graph:
+            return []
+
+        graph_dict = self.graph_dict
+        if 'nodes' not in graph_dict:
+            return []
+
+        start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None)
+        if not start_node:
+            return []
+
+        # get user_input_form from start node
+        return start_node.get('variables', [])
+
 
 class WorkflowRunTriggeredFrom(Enum):
     """
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 7dd5d770eaa1bd..e0a7835cb72a5d 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -175,12 +175,17 @@ def import_app(self, tenant_id: str, args: dict, account: Account) -> App:
         if workflow:
             # init draft workflow
             workflow_service = WorkflowService()
-            workflow_service.sync_draft_workflow(
+            draft_workflow = workflow_service.sync_draft_workflow(
                 app_model=app,
                 graph=workflow.get('graph'),
                 features=workflow.get('features'),
                 account=account
             )
+            workflow_service.publish_workflow(
+                app_model=app,
+                account=account,
+                draft_workflow=draft_workflow
+            )
 
         if model_config_data:
             app_model_config = AppModelConfig()
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index a9fe65df6fbb25..0123666644c6af 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -5,6 +5,7 @@
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from models.model import AppModelConfig, App, AppMode
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
@@ -20,7 +21,21 @@
 
 class AudioService:
     @classmethod
-    def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[str] = None):
+    def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise ValueError("Speech to text is not enabled")
+
+            features_dict = workflow.features_dict
+            if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
+                raise ValueError("Speech to text is not enabled")
+        else:
+            app_model_config: AppModelConfig = app_model.app_model_config
+
+            if not app_model_config.speech_to_text_dict['enabled']:
+                raise ValueError("Speech to text is not enabled")
+
         if file is None:
             raise NoAudioUploadedServiceError()
 
@@ -37,7 +52,7 @@ def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[st
 
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id,
+            tenant_id=app_model.tenant_id,
            model_type=ModelType.SPEECH2TEXT
         )
         if model_instance is None:
@@ -49,17 +64,41 @@
 
         return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
 
     @classmethod
-    def transcript_tts(cls, tenant_id: str, text: str, voice: str, streaming: bool, end_user: Optional[str] = None):
+    def transcript_tts(cls, app_model: App, text: str, streaming: bool, voice: Optional[str] = None, end_user: Optional[str] = None):
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise ValueError("TTS is not enabled")
+
+            features_dict = workflow.features_dict
+            if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
+                raise ValueError("TTS is not enabled")
+
+            voice = features_dict['text_to_speech'].get('voice') if not voice else voice
+        else:
+            text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
+
+            if not text_to_speech_dict.get('enabled'):
+                raise ValueError("TTS is not enabled")
+
+            voice = text_to_speech_dict.get('voice') if not voice else voice
+
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id,
+            tenant_id=app_model.tenant_id,
             model_type=ModelType.TTS
         )
         if model_instance is None:
             raise ProviderNotSupportTextToSpeechServiceError()
 
         try:
-            return model_instance.invoke_tts(content_text=text.strip(), user=end_user, streaming=streaming, tenant_id=tenant_id, voice=voice)
+            return model_instance.invoke_tts(
+                content_text=text.strip(),
+                user=end_user,
+                streaming=streaming,
+                tenant_id=app_model.tenant_id,
+                voice=voice
+            )
         except Exception as e:
             raise e

From 15c7e0ec2f2778f92c352b1373d0273afe6689f8 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 22:58:33 +0800
Subject: [PATCH 050/160] lint fix

---
 api/controllers/console/explore/audio.py     | 1 -
 api/controllers/console/explore/parameter.py | 2 +-
 api/controllers/service_api/app/audio.py     | 2 +-
 api/controllers/web/audio.py                 | 2 +-
 api/core/memory/token_buffer_memory.py       | 2 +-
 api/services/audio_service.py                | 2 +-
 6 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 34ce1ec1eec109..f03663f1a22ea3 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -19,7 +19,6 @@
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import AppModelConfig
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 0239742a4a70b7..9c0fca57f25a04 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -7,7 +7,7 @@
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
 from extensions.ext_database import db
-from models.model import AppModelConfig, InstalledApp, AppMode
+from models.model import AppMode, AppModelConfig, InstalledApp
 from models.tools import ApiToolProvider
 
 
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 57edab40901380..15c0a153b89283 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -20,7 +20,7 @@
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig, EndUser
+from models.model import App, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/web/audio.py 
b/api/controllers/web/audio.py index 8b8ab8f090b531..e0074c452fb851 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -19,7 +19,7 @@ from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig +from models.model import App from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index f9200dcc71d1f8..00813faef7ed84 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -10,7 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db -from models.model import Conversation, Message, AppMode +from models.model import AppMode, Conversation, Message class TokenBufferMemory: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 0123666644c6af..7a658487f83a2d 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,7 +5,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from models.model import AppModelConfig, App, AppMode +from models.model import App, AppMode, AppModelConfig from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, From 3f5d1a79c664109650b435bfeee9151afff1a798 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:18 +0800 Subject: [PATCH 051/160] refactor apps --- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 6 +- api/controllers/console/app/conversation.py | 8 +- api/controllers/console/app/message.py | 4 +- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/console/explore/message.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/web/completion.py | 2 +- api/controllers/web/message.py | 2 +- api/core/agent/base_agent_runner.py | 47 +- api/core/agent/cot_agent_runner.py | 33 +- api/core/agent/entities.py | 61 +++ api/core/agent/fc_agent_runner.py | 14 +- .../app/advanced_chat/config_validator.py | 59 --- .../{advanced_chat => app_config}/__init__.py | 0 .../app/app_config/base_app_config_manager.py | 73 +++ .../common}/__init__.py | 0 .../sensitive_word_avoidance}/__init__.py | 0 .../sensitive_word_avoidance/manager.py} | 19 +- .../easy_ui_based_app}/__init__.py | 0 .../easy_ui_based_app/agent}/__init__.py | 0 .../easy_ui_based_app/agent/manager.py | 79 ++++ .../easy_ui_based_app/dataset}/__init__.py | 0 .../easy_ui_based_app/dataset/manager.py} | 87 +++- .../model_config/__init__.py | 0 .../model_config/converter.py | 104 +++++ .../model_config/manager.py} | 36 +- .../prompt_template/__init__.py | 0 .../prompt_template/manager.py} | 59 ++- .../easy_ui_based_app/variables/__init__.py | 0 .../easy_ui_based_app/variables/manager.py | 184 ++++++++ .../app_config/entities.py} | 169 ++----- api/core/app/app_config/features/__init__.py | 0 .../features/file_upload/__init__.py | 0 .../features/file_upload/manager.py} | 26 +- .../features/more_like_this/__init__.py | 0 .../features/more_like_this/manager.py} | 15 +- 
.../features/opening_statement/__init__.py | 0 .../features/opening_statement/manager.py} | 18 +- .../features/retrieval_resource/__init__.py | 0 .../features/retrieval_resource/manager.py} | 10 +- .../features/speech_to_text/__init__.py | 0 .../features/speech_to_text/manager.py} | 15 +- .../__init__.py | 0 .../manager.py} | 18 +- .../features/text_to_speech/__init__.py | 0 .../features/text_to_speech/manager.py} | 22 +- .../workflow_ui_based_app/__init__.py | 0 .../variables/__init__.py | 0 .../variables/manager.py | 22 + api/core/app/app_manager.py | 196 +++++--- .../app/app_orchestration_config_converter.py | 421 ------------------ api/core/app/app_queue_manager.py | 4 +- api/core/app/apps/__init__.py | 0 api/core/app/apps/advanced_chat/__init__.py | 0 .../apps/advanced_chat/app_config_manager.py | 94 ++++ api/core/app/apps/agent_chat/__init__.py | 0 .../agent_chat/app_config_manager.py} | 116 +++-- .../app/{ => apps}/agent_chat/app_runner.py | 69 +-- api/core/app/{ => apps}/base_app_runner.py | 35 +- api/core/app/apps/chat/__init__.py | 0 api/core/app/apps/chat/app_config_manager.py | 135 ++++++ api/core/app/{ => apps}/chat/app_runner.py | 61 +-- api/core/app/apps/completion/__init__.py | 0 .../app/apps/completion/app_config_manager.py | 118 +++++ .../app/{ => apps}/completion/app_runner.py | 53 +-- api/core/app/apps/workflow/__init__.py | 0 .../app/apps/workflow/app_config_manager.py | 71 +++ api/core/app/chat/config_validator.py | 82 ---- api/core/app/completion/config_validator.py | 67 --- api/core/app/entities/__init__.py | 0 api/core/app/entities/app_invoke_entities.py | 111 +++++ api/core/{ => app}/entities/queue_entities.py | 0 .../annotation_reply/annotation_reply.py | 2 +- .../hosting_moderation/hosting_moderation.py | 7 +- api/core/app/generate_task_pipeline.py | 22 +- .../app/validators/external_data_fetch.py | 39 -- api/core/app/validators/user_input_form.py | 61 --- api/core/app/workflow/config_validator.py | 39 -- .../agent_loop_gather_callback_handler.py | 262 ----------- .../callback_handler/entity/agent_loop.py | 23 - .../index_tool_callback_handler.py | 2 +- .../external_data_tool/external_data_fetch.py | 2 +- api/core/file/file_obj.py | 5 +- api/core/file/message_file_parser.py | 35 +- api/core/helper/moderation.py | 4 +- api/core/memory/token_buffer_memory.py | 20 +- api/core/moderation/input_moderation.py | 10 +- api/core/prompt/advanced_prompt_transform.py | 15 +- api/core/prompt/prompt_transform.py | 6 +- api/core/prompt/simple_prompt_transform.py | 14 +- .../rag/retrieval/agent/agent_llm_callback.py | 101 ----- api/core/rag/retrieval/agent/llm_chain.py | 7 +- .../agent/multi_dataset_router_agent.py | 6 +- .../structed_multi_dataset_router_agent.py | 4 +- .../retrieval/agent_based_dataset_executor.py | 8 +- api/core/rag/retrieval/dataset_retrieval.py | 5 +- api/core/tools/tool/dataset_retriever_tool.py | 3 +- .../deduct_quota_when_messaeg_created.py | 8 +- ...vider_last_used_at_when_messaeg_created.py | 8 +- api/models/model.py | 12 + api/models/workflow.py | 2 +- api/services/app_model_config_service.py | 12 +- api/services/completion_service.py | 147 ++---- api/services/workflow/workflow_converter.py | 46 +- api/services/workflow_service.py | 8 +- .../prompt/test_advanced_prompt_transform.py | 10 +- .../core/prompt/test_prompt_transform.py | 2 +- .../prompt/test_simple_prompt_transform.py | 6 +- .../workflow/test_workflow_converter.py | 2 +- 111 files changed, 1980 insertions(+), 1820 deletions(-) create mode 100644 api/core/agent/entities.py delete mode 
100644 api/core/app/advanced_chat/config_validator.py rename api/core/app/{advanced_chat => app_config}/__init__.py (100%) create mode 100644 api/core/app/app_config/base_app_config_manager.py rename api/core/app/{agent_chat => app_config/common}/__init__.py (100%) rename api/core/app/{chat => app_config/common/sensitive_word_avoidance}/__init__.py (100%) rename api/core/app/{validators/moderation.py => app_config/common/sensitive_word_avoidance/manager.py} (64%) rename api/core/app/{completion => app_config/easy_ui_based_app}/__init__.py (100%) rename api/core/app/{validators => app_config/easy_ui_based_app/agent}/__init__.py (100%) create mode 100644 api/core/app/app_config/easy_ui_based_app/agent/manager.py rename api/core/app/{workflow => app_config/easy_ui_based_app/dataset}/__init__.py (100%) rename api/core/app/{validators/dataset_retrieval.py => app_config/easy_ui_based_app/dataset/manager.py} (63%) create mode 100644 api/core/app/app_config/easy_ui_based_app/model_config/__init__.py create mode 100644 api/core/app/app_config/easy_ui_based_app/model_config/converter.py rename api/core/app/{validators/model_validator.py => app_config/easy_ui_based_app/model_config/manager.py} (73%) create mode 100644 api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py rename api/core/app/{validators/prompt.py => app_config/easy_ui_based_app/prompt_template/manager.py} (58%) create mode 100644 api/core/app/app_config/easy_ui_based_app/variables/__init__.py create mode 100644 api/core/app/app_config/easy_ui_based_app/variables/manager.py rename api/core/{entities/application_entities.py => app/app_config/entities.py} (61%) create mode 100644 api/core/app/app_config/features/__init__.py create mode 100644 api/core/app/app_config/features/file_upload/__init__.py rename api/core/app/{validators/file_upload.py => app_config/features/file_upload/manager.py} (59%) create mode 100644 api/core/app/app_config/features/more_like_this/__init__.py rename api/core/app/{validators/more_like_this.py => app_config/features/more_like_this/manager.py} (63%) create mode 100644 api/core/app/app_config/features/opening_statement/__init__.py rename api/core/app/{validators/opening_statement.py => app_config/features/opening_statement/manager.py} (66%) create mode 100644 api/core/app/app_config/features/retrieval_resource/__init__.py rename api/core/app/{validators/retriever_resource.py => app_config/features/retrieval_resource/manager.py} (68%) create mode 100644 api/core/app/app_config/features/speech_to_text/__init__.py rename api/core/app/{validators/speech_to_text.py => app_config/features/speech_to_text/manager.py} (63%) create mode 100644 api/core/app/app_config/features/suggested_questions_after_answer/__init__.py rename api/core/app/{validators/suggested_questions.py => app_config/features/suggested_questions_after_answer/manager.py} (57%) create mode 100644 api/core/app/app_config/features/text_to_speech/__init__.py rename api/core/app/{validators/text_to_speech.py => app_config/features/text_to_speech/manager.py} (56%) create mode 100644 api/core/app/app_config/workflow_ui_based_app/__init__.py create mode 100644 api/core/app/app_config/workflow_ui_based_app/variables/__init__.py create mode 100644 api/core/app/app_config/workflow_ui_based_app/variables/manager.py delete mode 100644 api/core/app/app_orchestration_config_converter.py create mode 100644 api/core/app/apps/__init__.py create mode 100644 api/core/app/apps/advanced_chat/__init__.py create mode 100644 
api/core/app/apps/advanced_chat/app_config_manager.py create mode 100644 api/core/app/apps/agent_chat/__init__.py rename api/core/app/{agent_chat/config_validator.py => apps/agent_chat/app_config_manager.py} (51%) rename api/core/app/{ => apps}/agent_chat/app_runner.py (83%) rename api/core/app/{ => apps}/base_app_runner.py (93%) create mode 100644 api/core/app/apps/chat/__init__.py create mode 100644 api/core/app/apps/chat/app_config_manager.py rename api/core/app/{ => apps}/chat/app_runner.py (76%) create mode 100644 api/core/app/apps/completion/__init__.py create mode 100644 api/core/app/apps/completion/app_config_manager.py rename api/core/app/{ => apps}/completion/app_runner.py (74%) create mode 100644 api/core/app/apps/workflow/__init__.py create mode 100644 api/core/app/apps/workflow/app_config_manager.py delete mode 100644 api/core/app/chat/config_validator.py delete mode 100644 api/core/app/completion/config_validator.py create mode 100644 api/core/app/entities/__init__.py create mode 100644 api/core/app/entities/app_invoke_entities.py rename api/core/{ => app}/entities/queue_entities.py (100%) delete mode 100644 api/core/app/validators/external_data_fetch.py delete mode 100644 api/core/app/validators/user_input_form.py delete mode 100644 api/core/app/workflow/config_validator.py delete mode 100644 api/core/callback_handler/agent_loop_gather_callback_handler.py delete mode 100644 api/core/callback_handler/entity/agent_loop.py delete mode 100644 api/core/rag/retrieval/agent/agent_llm_callback.py diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index c7f3a598cafb4c..4de4a6f3fe82e6 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): file = request.files['file'] diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 0632c0439b29b5..ed1522c0cdf891 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -22,7 +22,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -103,7 +103,7 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') @@ -168,7 +168,7 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model, task_id): account = flask_login.current_user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b808d62eb017b0..33711076f8597c 100644 --- 
a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -112,7 +112,7 @@ def get(self, app_model, conversation_id): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -133,7 +133,7 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @marshal_with(conversation_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -227,7 +227,7 @@ def get(self, app_model, conversation_id): @setup_required @login_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index c384e878aaccf9..111ec7d787e58a 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -42,7 +42,7 @@ class ChatMessageListApi(Resource): @setup_required @login_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): @@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def get(self, app_model, message_id): message_id = str(message_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index e3a5112200568b..51fe53c0ec7050 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def get(self, app_model): account = current_user diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 22ea4bbac242ee..dd531974fa10cd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -22,7 +22,7 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db diff --git 
a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425fa896..fdb0eae24f00d4 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -24,7 +24,7 @@ NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index fd4ce831b37081..5c488093fa7b1e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -20,7 +20,7 @@ ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd94ec7646bd74..785e2b8d6b9225 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -21,7 +21,7 @@ ) from controllers.web.wraps import WebApiResource from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb2a27..1acb92dbf1eda5 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -21,7 +21,7 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1474c6a4757bf5..529240aecb7b30 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,17 +5,15 @@ from mimetypes import guess_extension from typing import Optional, Union, cast +from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_queue_manager import AppQueueManager -from core.app.base_app_runner import AppRunner +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from 
core.entities.application_entities import ( - AgentEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - InvokeFrom, - ModelConfigEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, + InvokeFrom, EasyUIBasedModelConfigEntity, ) from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory @@ -50,9 +48,9 @@ class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, - application_generate_entity: ApplicationGenerateEntity, - app_orchestration_config: AppOrchestrationConfigEntity, - model_config: ModelConfigEntity, + application_generate_entity: EasyUIBasedAppGenerateEntity, + app_config: AgentChatAppConfig, + model_config: EasyUIBasedModelConfigEntity, config: AgentEntity, queue_manager: AppQueueManager, message: Message, @@ -66,7 +64,7 @@ def __init__(self, tenant_id: str, """ Agent runner :param tenant_id: tenant id - :param app_orchestration_config: app orchestration config + :param app_config: agent chat app config :param model_config: model config :param config: agent config :param queue_manager: queue manager @@ -78,7 +76,7 @@ def __init__(self, tenant_id: str, """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity - self.app_orchestration_config = app_orchestration_config + self.app_config = app_config self.model_config = model_config self.config = config self.queue_manager = queue_manager @@ -97,16 +95,16 @@ def __init__(self, tenant_id: str, # init dataset tools hit_callback = DatasetIndexToolCallbackHandler( queue_manager=queue_manager, - app_id=self.application_generate_entity.app_id, + app_id=self.app_config.app_id, message_id=message.id, user_id=user_id, invoke_from=self.application_generate_entity.invoke_from, ) self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( tenant_id=tenant_id, - dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], - retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, - return_resource=app_orchestration_config.show_retrieve_source, + dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], + retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, + return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback ) @@ -124,14 +122,15 @@ def __init__(self, tenant_id: str, else: self.stream_tool_call = False - def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: + def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \ -> EasyUIBasedAppGenerateEntity: """ - Repack app orchestration config + Repack app generate entity """ - if app_orchestration_config.prompt_template.simple_prompt_template is None: - app_orchestration_config.prompt_template.simple_prompt_template = '' + if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: + app_generate_entity.app_config.prompt_template.simple_prompt_template = '' - return app_orchestration_config + return app_generate_entity def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: """ @@ -351,7 +350,7 @@ def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[ )) db.session.close() - + return result def
create_agent_thought(self, message_id: str, message: str, @@ -462,7 +461,7 @@ def save_agent_thought(self, db.session.commit() db.session.close() - + def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ Transform tool message into agent thought diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 5650113f470c2f..5b345f4da0c6a6 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -5,7 +5,7 @@ from core.agent.base_agent_runner import BaseAgentRunner from core.app.app_queue_manager import PublishFrom -from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit +from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -27,7 +27,7 @@ from models.model import Conversation, Message -class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): +class CotAgentRunner(BaseAgentRunner): _is_first_iteration = True _ignore_observation_providers = ['wenxin'] @@ -39,30 +39,33 @@ def run(self, conversation: Conversation, """ Run Cot agent application """ - app_orchestration_config = self.app_orchestration_config - self._repack_app_orchestration_config(app_orchestration_config) + app_generate_entity = self.application_generate_entity + self._repack_app_generate_entity(app_generate_entity) agent_scratchpad: list[AgentScratchpadUnit] = [] self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) - if 'Observation' not in app_orchestration_config.model_config.stop: - if app_orchestration_config.model_config.provider not in self._ignore_observation_providers: - app_orchestration_config.model_config.stop.append('Observation') + # check model mode + if 'Observation' not in app_generate_entity.model_config.stop: + if app_generate_entity.model_config.provider not in self._ignore_observation_providers: + app_generate_entity.model_config.stop.append('Observation') + + app_config = self.app_config # override inputs inputs = inputs or {} - instruction = self.app_orchestration_config.prompt_template.simple_prompt_template + instruction = app_config.prompt_template.simple_prompt_template instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 - max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 prompt_messages = self.history_prompt_messages # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -122,11 +125,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # update prompt messages prompt_messages = self._organize_cot_prompt_messages( - mode=app_orchestration_config.model_config.mode, + mode=app_generate_entity.model_config.mode, prompt_messages=prompt_messages, tools=prompt_messages_tools, agent_scratchpad=agent_scratchpad, - agent_prompt_message=app_orchestration_config.agent.prompt, + 
agent_prompt_message=app_config.agent.prompt, instruction=instruction, input=query ) @@ -136,9 +139,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, tools=[], - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=True, user=self.user_id, callbacks=[], @@ -550,7 +553,7 @@ def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpad """ convert agent scratchpad list to str """ - next_iteration = self.app_orchestration_config.agent.prompt.next_iteration + next_iteration = self.app_config.agent.prompt.next_iteration result = '' for scratchpad in agent_scratchpad: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py new file mode 100644 index 00000000000000..0fbfdc26367237 --- /dev/null +++ b/api/core/agent/entities.py @@ -0,0 +1,61 @@ +from enum import Enum +from typing import Literal, Any, Union, Optional + +from pydantic import BaseModel + + +class AgentToolEntity(BaseModel): + """ + Agent Tool Entity. + """ + provider_type: Literal["builtin", "api"] + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] = {} + + +class AgentPromptEntity(BaseModel): + """ + Agent Prompt Entity. + """ + first_prompt: str + next_iteration: str + + +class AgentScratchpadUnit(BaseModel): + """ + Agent First Prompt Entity. + """ + + class Action(BaseModel): + """ + Action Entity. + """ + action_name: str + action_input: Union[dict, str] + + agent_response: Optional[str] = None + thought: Optional[str] = None + action_str: Optional[str] = None + observation: Optional[str] = None + action: Optional[Action] = None + + +class AgentEntity(BaseModel): + """ + Agent Entity. + """ + + class Strategy(Enum): + """ + Agent Strategy. 
+ """ + CHAIN_OF_THOUGHT = 'chain-of-thought' + FUNCTION_CALLING = 'function-calling' + + provider: str + model: str + strategy: Strategy + prompt: Optional[AgentPromptEntity] = None + tools: list[AgentToolEntity] = None + max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9b238bf2324b84..30e5cdd6946a14 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -34,9 +34,11 @@ def run(self, conversation: Conversation, """ Run FunctionCall agent application """ - app_orchestration_config = self.app_orchestration_config + app_generate_entity = self.application_generate_entity - prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' + app_config = self.app_config + + prompt_template = app_config.prompt_template.simple_prompt_template or '' prompt_messages = self.history_prompt_messages prompt_messages = self.organize_prompt_messages( prompt_template=prompt_template, @@ -47,7 +49,7 @@ def run(self, conversation: Conversation, # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -67,7 +69,7 @@ def run(self, conversation: Conversation, tool_instances[dataset_tool.identity.name] = dataset_tool iteration_step = 1 - max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # continue to run until there is not any tool call function_call_state = True @@ -110,9 +112,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, tools=prompt_messages_tools, - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=self.stream_tool_call, user=self.user_id, callbacks=[], diff --git a/api/core/app/advanced_chat/config_validator.py b/api/core/app/advanced_chat/config_validator.py deleted file mode 100644 index a20198ef4a344d..00000000000000 --- a/api/core/app/advanced_chat/config_validator.py +++ /dev/null @@ -1,59 +0,0 @@ -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.opening_statement import OpeningStatementValidator -from core.app.validators.retriever_resource import RetrieverResourceValidator -from core.app.validators.speech_to_text import SpeechToTextValidator -from core.app.validators.suggested_questions import SuggestedQuestionsValidator -from core.app.validators.text_to_speech import TextToSpeechValidator - - -class AdvancedChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: - """ - Validate for advanced chat app model config - - :param tenant_id: tenant id - :param config: app model config args - :param only_structure_validate: if True, only structure validation will be performed - """ - related_config_keys = [] - - # 
file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate - ) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/app/advanced_chat/__init__.py b/api/core/app/app_config/__init__.py similarity index 100% rename from api/core/app/advanced_chat/__init__.py rename to api/core/app/app_config/__init__.py diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py new file mode 100644 index 00000000000000..b3c773203d7fbf --- /dev/null +++ b/api/core/app/app_config/base_app_config_manager.py @@ -0,0 +1,73 @@ +from typing import Union, Optional + +from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppModelConfig + + +class BaseAppConfigManager: + + @classmethod + def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: Union[AppModelConfig, dict], + config_dict: Optional[dict] = None) -> dict: + """ + Convert app model config to config dict + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + + return config_dict + + @classmethod + def convert_features(cls, config_dict: dict) -> 
AppAdditionalFeatures: + """ + Convert app config dict to additional features entity + + :param config_dict: app config + """ + config_dict = config_dict.copy() + + additional_features = AppAdditionalFeatures() + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( + config=config_dict + ) + + additional_features.file_upload = FileUploadConfigManager.convert( + config=config_dict + ) + + additional_features.opening_statement, additional_features.suggested_questions = \ + OpeningStatementConfigManager.convert( + config=config_dict + ) + + additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( + config=config_dict + ) + + additional_features.more_like_this = MoreLikeThisConfigManager.convert( + config=config_dict + ) + + additional_features.speech_to_text = SpeechToTextConfigManager.convert( + config=config_dict + ) + + additional_features.text_to_speech = TextToSpeechConfigManager.convert( + config=config_dict + ) + + return additional_features diff --git a/api/core/app/agent_chat/__init__.py b/api/core/app/app_config/common/__init__.py similarity index 100% rename from api/core/app/agent_chat/__init__.py rename to api/core/app/app_config/common/__init__.py diff --git a/api/core/app/chat/__init__.py b/api/core/app/app_config/common/sensitive_word_avoidance/__init__.py similarity index 100% rename from api/core/app/chat/__init__.py rename to api/core/app/app_config/common/sensitive_word_avoidance/__init__.py diff --git a/api/core/app/validators/moderation.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py similarity index 64% rename from api/core/app/validators/moderation.py rename to api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 7a5dff55c9e6e0..3dccfa3cbed2da 100644 --- a/api/core/app/validators/moderation.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,11 +1,24 @@ -import logging +from typing import Optional +from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory -logger = logging.getLogger(__name__) +class SensitiveWordAvoidanceConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: + sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + if not sensitive_word_avoidance_dict: + return None + + if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: + return SensitiveWordAvoidanceEntity( + type=sensitive_word_avoidance_dict.get('type'), + config=sensitive_word_avoidance_dict.get('config'), + ) + else: + return None -class ModerationValidator: @classmethod def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ -> tuple[dict, list[str]]: diff --git a/api/core/app/completion/__init__.py b/api/core/app/app_config/easy_ui_based_app/__init__.py similarity index 100% rename from api/core/app/completion/__init__.py rename to api/core/app/app_config/easy_ui_based_app/__init__.py diff --git a/api/core/app/validators/__init__.py b/api/core/app/app_config/easy_ui_based_app/agent/__init__.py similarity index 100% rename from api/core/app/validators/__init__.py rename to api/core/app/app_config/easy_ui_based_app/agent/__init__.py diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py new file mode 100644 index 00000000000000..b50b7f678c492d --- /dev/null +++
b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -0,0 +1,79 @@ +from typing import Optional + +from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity +from core.tools.prompt.template import REACT_PROMPT_TEMPLATES + + +class AgentConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[AgentEntity]: + """ + Convert model config to agent entity + + :param config: model config args + """ + if 'agent_mode' in config and config['agent_mode'] \ + and 'enabled' in config['agent_mode'] \ + and config['agent_mode']['enabled']: + + agent_dict = config.get('agent_mode', {}) + agent_strategy = agent_dict.get('strategy', 'cot') + + if agent_strategy == 'function_call': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + elif agent_strategy == 'cot' or agent_strategy == 'react': + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + else: + # old configs, try to detect default strategy + if config['model']['provider'] == 'openai': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + + agent_tools = [] + for tool in agent_dict.get('tools', []): + keys = tool.keys() + if len(keys) >= 4: + if "enabled" not in tool or not tool["enabled"]: + continue + + agent_tool_properties = { + 'provider_type': tool['provider_type'], + 'provider_id': tool['provider_id'], + 'tool_name': tool['tool_name'], + 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} + } + + agent_tools.append(AgentToolEntity(**agent_tool_properties)) + + if 'strategy' in config['agent_mode'] and \ + config['agent_mode']['strategy'] not in ['react_router', 'router']: + agent_prompt = agent_dict.get('prompt', None) or {} + # check model mode + model_mode = config.get('model', {}).get('mode', 'completion') + if model_mode == 'completion': + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['completion'][ + 'agent_scratchpad']), + ) + else: + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + ) + + return AgentEntity( + provider=config['model']['provider'], + model=config['model']['name'], + strategy=strategy, + prompt=agent_prompt_entity, + tools=agent_tools, + max_iteration=agent_dict.get('max_iteration', 5) + ) + + return None diff --git a/api/core/app/workflow/__init__.py b/api/core/app/app_config/easy_ui_based_app/dataset/__init__.py similarity index 100% rename from api/core/app/workflow/__init__.py rename to api/core/app/app_config/easy_ui_based_app/dataset/__init__.py diff --git a/api/core/app/validators/dataset_retrieval.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py similarity index 63% rename from api/core/app/validators/dataset_retrieval.py rename to api/core/app/app_config/easy_ui_based_app/dataset/manager.py index fb5b64832073ae..4c08f62d27217b 100644 --- a/api/core/app/validators/dataset_retrieval.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,11 +1,94 @@ -import uuid +from typing import Optional +from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.entities.agent_entities import PlanningStrategy from
models.model import AppMode from services.dataset_service import DatasetService -class DatasetValidator: +class DatasetConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[DatasetEntity]: + """ + Convert model config to dataset entity + + :param config: model config args + """ + dataset_ids = [] + if 'datasets' in config.get('dataset_configs', {}): + datasets = config.get('dataset_configs', {}).get('datasets', { + 'strategy': 'router', + 'datasets': [] + }) + + for dataset in datasets.get('datasets', []): + keys = list(dataset.keys()) + if len(keys) == 0 or keys[0] != 'dataset': + continue + + dataset = dataset['dataset'] + + if 'enabled' not in dataset or not dataset['enabled']: + continue + + dataset_id = dataset.get('id', None) + if dataset_id: + dataset_ids.append(dataset_id) + + if 'agent_mode' in config and config['agent_mode'] \ + and 'enabled' in config['agent_mode'] \ + and config['agent_mode']['enabled']: + + agent_dict = config.get('agent_mode', {}) + + for tool in agent_dict.get('tools', []): + keys = tool.keys() + if len(keys) == 1: + # old standard + key = list(tool.keys())[0] + + if key != 'dataset': + continue + + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + continue + + dataset_id = tool_item['id'] + dataset_ids.append(dataset_id) + + if len(dataset_ids) == 0: + return None + + # dataset configs + dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'}) + query_variable = config.get('dataset_query_variable') + + if dataset_configs['retrieval_model'] == 'single': + return DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ) + ) + ) + else: + return DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ), + top_k=dataset_configs.get('top_k'), + score_threshold=dataset_configs.get('score_threshold'), + reranking_model=dataset_configs.get('reranking_model') + ) + ) + @classmethod def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py b/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py new file mode 100644 index 00000000000000..05fcb107919673 --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -0,0 +1,104 @@ +from typing import cast + +from core.app.app_config.entities import EasyUIBasedAppConfig +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity + +from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.provider_manager import ProviderManager + + +class EasyUIBasedModelConfigEntityConverter: + @classmethod + def convert(cls, app_config:
EasyUIBasedAppConfig, + skip_check: bool = False) \ + -> EasyUIBasedModelConfigEntity: + """ + Convert app model config dict to entity. + :param app_config: app config + :param skip_check: skip check + :raises ProviderTokenNotInitError: provider token not init error + :return: app orchestration config entity + """ + model_config = app_config.model + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=app_config.tenant_id, + provider=model_config.provider, + model_type=ModelType.LLM + ) + + provider_name = provider_model_bundle.configuration.provider.provider + model_name = model_config.model + + model_type_instance = provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + # check model credentials + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, + model=model_config.model + ) + + if model_credentials is None: + if not skip_check: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + else: + model_credentials = {} + + if not skip_check: + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, + model_type=ModelType.LLM + ) + + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = model_config.parameters + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = model_config.mode + if not model_mode: + mode_enum = model_type_instance.get_model_mode( + model=model_config.model, + credentials=model_credentials + ) + + model_mode = mode_enum.value + + model_schema = model_type_instance.get_model_schema( + model_config.model, + model_credentials + ) + + if not skip_check and not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return EasyUIBasedModelConfigEntity( + provider=model_config.provider, + model=model_config.model, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) diff --git a/api/core/app/validators/model_validator.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py similarity index 73% rename from api/core/app/validators/model_validator.py rename to api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 1d86fbaf04e836..5cca2bc1a74be5 100644 --- a/api/core/app/validators/model_validator.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,10 +1,40 @@ - -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.app.app_config.entities import ModelConfigEntity +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey from core.model_runtime.model_providers import model_provider_factory from core.provider_manager 
import ProviderManager -class ModelValidator: +class ModelConfigManager: + @classmethod + def convert(cls, config: dict) -> ModelConfigEntity: + """ + Convert model config to model config + + :param config: model config args + """ + # model config + model_config = config.get('model') + + if not model_config: + raise ValueError("model is required") + + completion_params = model_config.get('completion_params') + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = model_config.get('mode') + + return ModelConfigEntity( + provider=config['model']['provider'], + model=config['model']['name'], + mode=model_mode, + parameters=completion_params, + stop=stop, + ) + @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/validators/prompt.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py similarity index 58% rename from api/core/app/validators/prompt.py rename to api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 288a5234155b2c..5629d0d09e4537 100644 --- a/api/core/app/validators/prompt.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,10 +1,61 @@ - -from core.entities.application_entities import PromptTemplateEntity +from core.app.app_config.entities import PromptTemplateEntity, \ + AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.model_runtime.entities.message_entities import PromptMessageRole from core.prompt.simple_prompt_transform import ModelMode from models.model import AppMode -class PromptValidator: +class PromptTemplateConfigManager: + @classmethod + def convert(cls, config: dict) -> PromptTemplateEntity: + if not config.get("prompt_type"): + raise ValueError("prompt_type is required") + + prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + simple_prompt_template = config.get("pre_prompt", "") + return PromptTemplateEntity( + prompt_type=prompt_type, + simple_prompt_template=simple_prompt_template + ) + else: + advanced_chat_prompt_template = None + chat_prompt_config = config.get("chat_prompt_config", {}) + if chat_prompt_config: + chat_prompt_messages = [] + for message in chat_prompt_config.get("prompt", []): + chat_prompt_messages.append({ + "text": message["text"], + "role": PromptMessageRole.value_of(message["role"]) + }) + + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( + messages=chat_prompt_messages + ) + + advanced_completion_prompt_template = None + completion_prompt_config = config.get("completion_prompt_config", {}) + if completion_prompt_config: + completion_prompt_template_params = { + 'prompt': completion_prompt_config['prompt']['text'], + } + + if 'conversation_histories_role' in completion_prompt_config: + completion_prompt_template_params['role_prefix'] = { + 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], + 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + } + + advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( + **completion_prompt_template_params + ) + + return 
PromptTemplateEntity( + prompt_type=prompt_type, + advanced_chat_prompt_template=advanced_chat_prompt_template, + advanced_completion_prompt_template=advanced_completion_prompt_template + ) + @classmethod def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ @@ -83,4 +134,4 @@ def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: if not isinstance(config["post_prompt"], str): raise ValueError("post_prompt must be of string type") - return config \ No newline at end of file + return config diff --git a/api/core/app/app_config/easy_ui_based_app/variables/__init__.py b/api/core/app/app_config/easy_ui_based_app/variables/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py new file mode 100644 index 00000000000000..ff962a5439512a --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -0,0 +1,184 @@ +import re +from typing import Tuple + +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity +from core.external_data_tool.factory import ExternalDataToolFactory + + +class BasicVariablesConfigManager: + @classmethod + def convert(cls, config: dict) -> Tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + """ + Convert model config to variables and external data variables + + :param config: model config args + """ + external_data_variables = [] + variables = [] + + # old external_data_tools + external_data_tools = config.get('external_data_tools', []) + for external_data_tool in external_data_tools: + if 'enabled' not in external_data_tool or not external_data_tool['enabled']: + continue + + external_data_variables.append( + ExternalDataVariableEntity( + variable=external_data_tool['variable'], + type=external_data_tool['type'], + config=external_data_tool['config'] + ) + ) + + # variables and external_data_tools + for variable in config.get('user_input_form', []): + typ = list(variable.keys())[0] + if typ == 'external_data_tool': + val = variable[typ] + external_data_variables.append( + ExternalDataVariableEntity( + variable=val['variable'], + type=val['type'], + config=val['config'] + ) + ) + elif typ in [ + VariableEntity.Type.TEXT_INPUT.value, + VariableEntity.Type.PARAGRAPH.value, + VariableEntity.Type.NUMBER.value, + ]: + variables.append( + VariableEntity( + type=VariableEntity.Type.value_of(typ), + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + max_length=variable[typ].get('max_length'), + default=variable[typ].get('default'), + ) + ) + elif typ == VariableEntity.Type.SELECT.value: + variables.append( + VariableEntity( + type=VariableEntity.Type.SELECT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + options=variable[typ].get('options'), + default=variable[typ].get('default'), + ) + ) + + return variables, external_data_variables + + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for user input form + + :param tenant_id: workspace id + :param config: app model config args + """ + related_config_keys = [] + config, current_related_config_keys =
cls.validate_variables_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + return config, related_config_keys + + @classmethod + def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for user input form + + :param config: app model config args + """ + if not config.get("user_input_form"): + config["user_input_form"] = [] + + if not isinstance(config["user_input_form"], list): + raise ValueError("user_input_form must be a list of objects") + + variables = [] + for item in config["user_input_form"]: + key = list(item.keys())[0] + if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph', 'select', 'number' or 'external_data_tool'") + + form_item = item[key] + if 'label' not in form_item: + raise ValueError("label is required in user_input_form") + + if not isinstance(form_item["label"], str): + raise ValueError("label in user_input_form must be of string type") + + if 'variable' not in form_item: + raise ValueError("variable is required in user_input_form") + + if not isinstance(form_item["variable"], str): + raise ValueError("variable in user_input_form must be of string type") + + pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") + if pattern.match(form_item["variable"]) is None: + raise ValueError("variable in user_input_form must be a string, " + "and cannot start with a number") + + variables.append(form_item["variable"]) + + if 'required' not in form_item or not form_item["required"]: + form_item["required"] = False + + if not isinstance(form_item["required"], bool): + raise ValueError("required in user_input_form must be of boolean type") + + if key == "select": + if 'options' not in form_item or not form_item["options"]: + form_item["options"] = [] + + if not isinstance(form_item["options"], list): + raise ValueError("options in user_input_form must be a list of strings") + + if "default" in form_item and form_item['default'] \ + and form_item["default"] not in form_item["options"]: + raise ValueError("default value in user_input_form must be in the options list") + + return config, ["user_input_form"] + + @classmethod + def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for external data fetch feature + + :param tenant_id: workspace id + :param config: app model config args + """ + if not config.get("external_data_tools"): + config["external_data_tools"] = [] + + if not isinstance(config["external_data_tools"], list): + raise ValueError("external_data_tools must be of list type") + + for tool in config["external_data_tools"]: + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + + if not tool["enabled"]: + continue + + if "type" not in tool or not tool["type"]: + raise ValueError("external_data_tools[].type is required") + + typ = tool["type"] + tool_config = tool["config"] + + ExternalDataToolFactory.validate_config( + name=typ, + tenant_id=tenant_id, + config=tool_config + ) + + return config, ["external_data_tools"] \ No newline at end of file diff --git a/api/core/entities/application_entities.py b/api/core/app/app_config/entities.py similarity index 61% rename from
api/core/entities/application_entities.py rename to api/core/app/app_config/entities.py index f5ea4d1eb0f8e0..e155dc1c4dd388 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/app/app_config/entities.py @@ -1,12 +1,10 @@ from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Optional from pydantic import BaseModel -from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import AIModelEntity +from models.model import AppMode class ModelConfigEntity(BaseModel): @@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel): """ provider: str model: str - model_schema: Optional[AIModelEntity] = None - mode: str - provider_model_bundle: ProviderModelBundle - credentials: dict[str, Any] = {} + mode: Optional[str] = None parameters: dict[str, Any] = {} stop: list[str] = [] @@ -194,149 +189,53 @@ class FileUploadEntity(BaseModel): image_config: Optional[dict[str, Any]] = None -class AgentToolEntity(BaseModel): - """ - Agent Tool Entity. - """ - provider_type: Literal["builtin", "api"] - provider_id: str - tool_name: str - tool_parameters: dict[str, Any] = {} - - -class AgentPromptEntity(BaseModel): - """ - Agent Prompt Entity. - """ - first_prompt: str - next_iteration: str - - -class AgentScratchpadUnit(BaseModel): - """ - Agent First Prompt Entity. - """ - - class Action(BaseModel): - """ - Action Entity. - """ - action_name: str - action_input: Union[dict, str] - - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None - - -class AgentEntity(BaseModel): - """ - Agent Entity. - """ - - class Strategy(Enum): - """ - Agent Strategy. - """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' - - provider: str - model: str - strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None - max_iteration: int = 5 - - -class AppOrchestrationConfigEntity(BaseModel): - """ - App Orchestration Config Entity. - """ - model_config: ModelConfigEntity - prompt_template: PromptTemplateEntity - variables: list[VariableEntity] = [] - external_data_variables: list[ExternalDataVariableEntity] = [] - agent: Optional[AgentEntity] = None - - # features - dataset: Optional[DatasetEntity] = None +class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileUploadEntity] = None opening_statement: Optional[str] = None + suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False text_to_speech: Optional[TextToSpeechEntity] = None - sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None -class InvokeFrom(Enum): +class AppConfig(BaseModel): """ - Invoke From. + Application Config Entity. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' - - @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid invoke from value {value}') - - def to_source(self) -> str: - """ - Get source of invoke from. 
+    tenant_id: str
+    app_id: str
+    app_mode: AppMode
+    additional_features: AppAdditionalFeatures
+    variables: list[VariableEntity] = []
+    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
 
-        :return: source
-        """
-        if self == InvokeFrom.WEB_APP:
-            return 'web_app'
-        elif self == InvokeFrom.DEBUGGER:
-            return 'dev'
-        elif self == InvokeFrom.EXPLORE:
-            return 'explore_app'
-        elif self == InvokeFrom.SERVICE_API:
-            return 'api'
 
-        return 'dev'
+class EasyUIBasedAppModelConfigFrom(Enum):
+    """
+    App Model Config From.
+    """
+    ARGS = 'args'
+    APP_LATEST_CONFIG = 'app-latest-config'
+    CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
 
-class ApplicationGenerateEntity(BaseModel):
+class EasyUIBasedAppConfig(AppConfig):
     """
-    Application Generate Entity.
+    Easy UI Based App Config Entity.
     """
-    task_id: str
-    tenant_id: str
-
-    app_id: str
+    app_model_config_from: EasyUIBasedAppModelConfigFrom
     app_model_config_id: str
-    # for save
     app_model_config_dict: dict
-    app_model_config_override: bool
-
-    # Converted from app_model_config to Entity object, or directly covered by external input
-    app_orchestration_config_entity: AppOrchestrationConfigEntity
-
-    conversation_id: Optional[str] = None
-    inputs: dict[str, str]
-    query: Optional[str] = None
-    files: list[FileObj] = []
-    user_id: str
-    # extras
-    stream: bool
-    invoke_from: InvokeFrom
-
-    # extra parameters, like: auto_generate_conversation_name
-    extras: dict[str, Any] = {}
+    model: ModelConfigEntity
+    prompt_template: PromptTemplateEntity
+    dataset: Optional[DatasetEntity] = None
+    external_data_variables: list[ExternalDataVariableEntity] = []
+
+
+class WorkflowUIBasedAppConfig(AppConfig):
+    """
+    Workflow UI Based App Config Entity.
+    """
+    workflow_id: str
diff --git a/api/core/app/app_config/features/__init__.py b/api/core/app/app_config/features/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/file_upload/__init__.py b/api/core/app/app_config/features/file_upload/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/validators/file_upload.py b/api/core/app/app_config/features/file_upload/manager.py
similarity index 59%
rename from api/core/app/validators/file_upload.py
rename to api/core/app/app_config/features/file_upload/manager.py
index 419465bd5119ab..63830696ffd28c 100644
--- a/api/core/app/validators/file_upload.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,6 +1,30 @@
+from typing import Optional
+
+from core.app.app_config.entities import FileUploadEntity
+
+
+class FileUploadConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[FileUploadEntity]:
+        """
+        Convert model config to file upload entity
+
+        :param config: model config args
+        """
+        file_upload_dict = config.get('file_upload')
+        if file_upload_dict:
+            if 'image' in file_upload_dict and file_upload_dict['image']:
+                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+                    return FileUploadEntity(
+                        image_config={
+                            'number_limits': file_upload_dict['image']['number_limits'],
+                            'detail': file_upload_dict['image']['detail'],
+                            'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        }
+                    )
+
+        return None
 
-class FileUploadValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/more_like_this/__init__.py b/api/core/app/app_config/features/more_like_this/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
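As a quick usage sketch of the manager above (config values are hypothetical; only the module path introduced in this patch is assumed):

    from core.app.app_config.features.file_upload.manager import FileUploadConfigManager

    # enabled image upload -> an entity carrying the image settings
    config = {
        'file_upload': {
            'image': {
                'enabled': True,
                'number_limits': 2,
                'detail': 'high',
                'transfer_methods': ['remote_url']
            }
        }
    }
    entity = FileUploadConfigManager.convert(config)
    assert entity is not None and entity.image_config['number_limits'] == 2

    # feature absent or disabled -> None
    assert FileUploadConfigManager.convert({}) is None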
diff --git a/api/core/app/validators/more_like_this.py b/api/core/app/app_config/features/more_like_this/manager.py
similarity index 63%
rename from api/core/app/validators/more_like_this.py
rename to api/core/app/app_config/features/more_like_this/manager.py
index 1c1bac9de64431..ec2a9a679611c7 100644
--- a/api/core/app/validators/more_like_this.py
+++ b/api/core/app/app_config/features/more_like_this/manager.py
@@ -1,6 +1,19 @@
+class MoreLikeThisConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to more_like_this flag
+
+        :param config: model config args
+        """
+        more_like_this = False
+        more_like_this_dict = config.get('more_like_this')
+        if more_like_this_dict:
+            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+                more_like_this = True
+
+        return more_like_this
 
-class MoreLikeThisValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/opening_statement/__init__.py b/api/core/app/app_config/features/opening_statement/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/validators/opening_statement.py b/api/core/app/app_config/features/opening_statement/manager.py
similarity index 66%
rename from api/core/app/validators/opening_statement.py
rename to api/core/app/app_config/features/opening_statement/manager.py
index f919230e0d1611..6183c6e7493311 100644
--- a/api/core/app/validators/opening_statement.py
+++ b/api/core/app/app_config/features/opening_statement/manager.py
@@ -1,6 +1,21 @@
 
-class OpeningStatementValidator:
+class OpeningStatementConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> tuple[str, list]:
+        """
+        Convert model config to opening statement and suggested questions
+
+        :param config: model config args
+        """
+        # opening statement
+        opening_statement = config.get('opening_statement')
+
+        # suggested questions
+        suggested_questions_list = config.get('suggested_questions')
+
+        return opening_statement, suggested_questions_list
+
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/retrieval_resource/__init__.py b/api/core/app/app_config/features/retrieval_resource/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/validators/retriever_resource.py b/api/core/app/app_config/features/retrieval_resource/manager.py
similarity index 68%
rename from api/core/app/validators/retriever_resource.py
rename to api/core/app/app_config/features/retrieval_resource/manager.py
index 32725c74328fbd..0694cb954e47e9 100644
--- a/api/core/app/validators/retriever_resource.py
+++ b/api/core/app/app_config/features/retrieval_resource/manager.py
@@ -1,6 +1,14 @@
+class RetrievalResourceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        show_retrieve_source = False
+        retriever_resource_dict = config.get('retriever_resource')
+        if retriever_resource_dict:
+            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+                show_retrieve_source = True
+
+        return show_retrieve_source
 
-class RetrieverResourceValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/speech_to_text/__init__.py b/api/core/app/app_config/features/speech_to_text/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/validators/speech_to_text.py b/api/core/app/app_config/features/speech_to_text/manager.py
similarity index 63%
rename from api/core/app/validators/speech_to_text.py
rename to api/core/app/app_config/features/speech_to_text/manager.py
index 92a1b25ae69085..b98699bfffdc87 100644
--- a/api/core/app/validators/speech_to_text.py
+++ b/api/core/app/app_config/features/speech_to_text/manager.py
@@ -1,6 +1,19 @@
+class SpeechToTextConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to speech_to_text flag
+
+        :param config: model config args
+        """
+        speech_to_text = False
+        speech_to_text_dict = config.get('speech_to_text')
+        if speech_to_text_dict:
+            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+                speech_to_text = True
+
+        return speech_to_text
 
-class SpeechToTextValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py b/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/validators/suggested_questions.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
similarity index 57%
rename from api/core/app/validators/suggested_questions.py
rename to api/core/app/app_config/features/suggested_questions_after_answer/manager.py
index 9161b316781489..5aacd3b32d3e52 100644
--- a/api/core/app/validators/suggested_questions.py
+++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
@@ -1,6 +1,19 @@
+class SuggestedQuestionsAfterAnswerConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to suggested_questions_after_answer flag
+
+        :param config: model config args
+        """
+        suggested_questions_after_answer = False
+        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
+        if suggested_questions_after_answer_dict:
+            if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+                suggested_questions_after_answer = True
+
+        return suggested_questions_after_answer
 
-class SuggestedQuestionsValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
@@ -16,7 +29,8 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")
 
-        if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]:
+        if "enabled" not in config["suggested_questions_after_answer"] or not \
+                config["suggested_questions_after_answer"]["enabled"]:
             config["suggested_questions_after_answer"]["enabled"] = False
 
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
diff --git a/api/core/app/app_config/features/text_to_speech/__init__.py b/api/core/app/app_config/features/text_to_speech/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
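The boolean feature managers above all follow the same convert() pattern; a minimal sketch of the expected behaviour (module paths as introduced in this patch):

    from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
    from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager

    features = {
        'speech_to_text': {'enabled': True},
        'more_like_this': {'enabled': False},
    }
    assert SpeechToTextConfigManager.convert(features) is True
    assert MoreLikeThisConfigManager.convert(features) is False
    # a missing key defaults to disabled
    assert MoreLikeThisConfigManager.convert({}) is False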
diff --git a/api/core/app/validators/text_to_speech.py b/api/core/app/app_config/features/text_to_speech/manager.py
similarity index 56%
rename from api/core/app/validators/text_to_speech.py
rename to api/core/app/app_config/features/text_to_speech/manager.py
index 182a912d52c692..1ff31034ad48e8 100644
--- a/api/core/app/validators/text_to_speech.py
+++ b/api/core/app/app_config/features/text_to_speech/manager.py
@@ -1,6 +1,29 @@
+from typing import Optional
+
+from core.app.app_config.entities import TextToSpeechEntity
 
-class TextToSpeechValidator:
+class TextToSpeechConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[TextToSpeechEntity]:
+        """
+        Convert model config to text_to_speech entity
+
+        :param config: model config args
+        """
+        # return None (not False) when disabled, so the value fits the Optional entity field
+        text_to_speech = None
+        text_to_speech_dict = config.get('text_to_speech')
+        if text_to_speech_dict:
+            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
+                text_to_speech = TextToSpeechEntity(
+                    enabled=text_to_speech_dict.get('enabled'),
+                    voice=text_to_speech_dict.get('voice'),
+                    language=text_to_speech_dict.get('language'),
+                )
+
+        return text_to_speech
+
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/workflow_ui_based_app/__init__.py b/api/core/app/app_config/workflow_ui_based_app/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py b/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
new file mode 100644
index 00000000000000..4b117d87f8c157
--- /dev/null
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -0,0 +1,22 @@
+from core.app.app_config.entities import VariableEntity
+from models.workflow import Workflow
+
+
+class WorkflowVariablesConfigManager:
+    @classmethod
+    def convert(cls, workflow: Workflow) -> list[VariableEntity]:
+        """
+        Convert workflow start node user input form to variable entities
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # find start node user input form
+        user_input_form = workflow.user_input_form()
+
+        # convert to variable entities
+        for variable in user_input_form:
+            variables.append(VariableEntity(**variable))
+
+        return variables
diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py
index 86c8d2cfc735b2..98ebe2c87d8ab8 100644
--- a/api/core/app/app_manager.py
+++ b/api/core/app/app_manager.py
@@ -8,13 +8,18 @@
 from flask import Flask, current_app
 from pydantic import ValidationError
 
-from core.app.agent_chat.app_runner import AgentChatAppRunner
-from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter
+from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter
+from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, EasyUIBasedAppConfig, VariableEntity
+from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
+from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
-from core.app.chat.app_runner import ChatAppRunner
+from core.app.apps.chat.app_config_manager import ChatAppConfigManager
+from core.app.apps.chat.app_runner import ChatAppRunner
+from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
+from core.app.apps.completion.app_runner import CompletionAppRunner
 from core.app.generate_task_pipeline import GenerateTaskPipeline
-from core.entities.application_entities import (
-    ApplicationGenerateEntity,
+from core.app.entities.app_invoke_entities import (
+
EasyUIBasedAppGenerateEntity, InvokeFrom, ) from core.file.file_obj import FileObj @@ -23,24 +28,19 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile +from models.model import App, Conversation, EndUser, Message, MessageFile, AppMode, AppModelConfig logger = logging.getLogger(__name__) -class AppManager: - """ - This class is responsible for managing application - """ +class EasyUIBasedAppManager: - def generate(self, tenant_id: str, - app_id: str, - app_model_config_id: str, - app_model_config_dict: dict, - app_model_config_override: bool, + def generate(self, app_model: App, + app_model_config: AppModelConfig, user: Union[Account, EndUser], invoke_from: InvokeFrom, inputs: dict[str, str], + app_model_config_dict: Optional[dict] = None, query: Optional[str] = None, files: Optional[list[FileObj]] = None, conversation: Optional[Conversation] = None, @@ -50,14 +50,12 @@ def generate(self, tenant_id: str, """ Generate App response. - :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override + :param app_model: App + :param app_model_config: app model config :param user: account or end user :param invoke_from: invoke from source :param inputs: inputs + :param app_model_config_dict: app model config dict :param query: query :param files: file obj list :param conversation: conversation @@ -67,20 +65,21 @@ def generate(self, tenant_id: str, # init task id task_id = str(uuid.uuid4()) + # convert to app config + app_config = self.convert_to_app_config( + app_model=app_model, + app_model_config=app_model_config, + app_model_config_dict=app_model_config_dict, + conversation=conversation + ) + # init application generate entity - application_generate_entity = ApplicationGenerateEntity( + application_generate_entity = EasyUIBasedAppGenerateEntity( task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, - app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, + app_config=app_config, + model_config=EasyUIBasedModelConfigEntityConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), query=query.replace('\x00', '') if query else None, files=files if files else [], user_id=user.id, @@ -89,7 +88,7 @@ def generate(self, tenant_id: str, extras=extras ) - if not stream and application_generate_entity.app_orchestration_config_entity.agent: + if not stream and application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: raise ValueError("Agent app is not supported in blocking mode.") # init generate records @@ -128,8 +127,85 @@ def generate(self, tenant_id: str, stream=stream ) + def convert_to_app_config(self, app_model: App, + app_model_config: AppModelConfig, + app_model_config_dict: Optional[dict] = None, + conversation: Optional[Conversation] = None) -> EasyUIBasedAppConfig: + if app_model_config_dict: + config_from = 
EasyUIBasedAppModelConfigFrom.ARGS + elif conversation: + config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_model.mode = AppMode.AGENT_CHAT.value + app_config = AgentChatAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + elif app_mode == AppMode.CHAT: + app_config = ChatAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + elif app_mode == AppMode.COMPLETION: + app_config = CompletionAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + else: + raise ValueError("Invalid app mode") + + return app_config + + def _get_cleaned_inputs(self, user_inputs: dict, app_config: EasyUIBasedAppConfig): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + # Filter input variables from form configuration, handle required fields, default values, and option values + variables = app_config.variables + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"{variable} is required in input form") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs + def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, + application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, message_id: str) -> None: @@ -148,7 +224,7 @@ def _generate_worker(self, flask_app: Flask, conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - if application_generate_entity.app_orchestration_config_entity.agent: + if application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: # agent app runner = AgentChatAppRunner() runner.run( @@ -157,8 +233,8 @@ def _generate_worker(self, flask_app: Flask, conversation=conversation, message=message ) - else: - # basic app + elif application_generate_entity.app_config.app_mode == AppMode.CHAT: + # chatbot app runner = ChatAppRunner() runner.run( application_generate_entity=application_generate_entity, @@ -166,6 +242,16 @@ def _generate_worker(self, flask_app: Flask, conversation=conversation, message=message ) + elif application_generate_entity.app_config.app_mode == AppMode.COMPLETION: + # completion app + 
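+                # NOTE: completion apps are single-turn, so there is no
+                # conversation to pass; the runner streams into the message only
+                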
runner = CompletionAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message=message + ) + else: + raise ValueError("Invalid app mode") except ConversationTaskStoppedException: pass except InvokeAuthorizationError: @@ -184,7 +270,7 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.remove() - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, + def _handle_response(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -217,24 +303,24 @@ def _handle_response(self, application_generate_entity: ApplicationGenerateEntit finally: db.session.remove() - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ + def _init_generate_records(self, application_generate_entity: EasyUIBasedAppGenerateEntity) \ -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity :return: """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance + model_type_instance = application_generate_entity.model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials + model=application_generate_entity.model_config.model, + credentials=application_generate_entity.model_config.credentials ) + app_config = application_generate_entity.app_config + app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) + .filter(App.id == app_config.app_id).first()) app_mode = app_record.mode @@ -249,8 +335,8 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat account_id = application_generate_entity.user_id override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS: + override_model_configs = app_config.app_model_config_dict introduction = '' if app_mode == 'chat': @@ -260,9 +346,9 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat if not application_generate_entity.conversation_id: conversation = Conversation( app_id=app_record.id, - app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, + app_model_config_id=app_config.app_model_config_id, + model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_mode, name='New conversation', @@ -291,8 +377,8 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat message = Message( app_id=app_record.id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, + 
model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, conversation_id=conversation.id, inputs=application_generate_entity.inputs, @@ -311,7 +397,7 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat from_source=from_source, from_end_user_id=end_user_id, from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None + agent_based=app_config.app_mode == AppMode.AGENT_CHAT, ) db.session.add(message) @@ -333,14 +419,14 @@ def _init_generate_records(self, application_generate_entity: ApplicationGenerat return conversation, message - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: + def _get_conversation_introduction(self, application_generate_entity: EasyUIBasedAppGenerateEntity) -> str: """ Get conversation introduction :param application_generate_entity: application generate entity :return: conversation introduction """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement + app_config = application_generate_entity.app_config + introduction = app_config.additional_features.opening_statement if introduction: try: diff --git a/api/core/app/app_orchestration_config_converter.py b/api/core/app/app_orchestration_config_converter.py deleted file mode 100644 index 1d429ee6d91a46..00000000000000 --- a/api/core/app/app_orchestration_config_converter.py +++ /dev/null @@ -1,421 +0,0 @@ -from typing import cast - -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, - VariableEntity, -) -from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.provider_manager import ProviderManager -from core.tools.prompt.template import REACT_PROMPT_TEMPLATES - - -class AppOrchestrationConfigConverter: - @classmethod - def convert_from_app_model_config_dict(cls, tenant_id: str, - app_model_config_dict: dict, - skip_check: bool = False) \ - -> AppOrchestrationConfigEntity: - """ - Convert app model config dict to entity. 
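The deleted converter below used to build one monolithic AppOrchestrationConfigEntity; its duties now split across the per-mode config managers. A rough sketch of how a call site migrates (assuming only the entities introduced in this patch):

    from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity

    def opening_statement_of(entity: EasyUIBasedAppGenerateEntity) -> str:
        # before: entity.app_orchestration_config_entity.opening_statement
        # after:  feature settings hang off app_config.additional_features
        return entity.app_config.additional_features.opening_statement or ''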
- :param tenant_id: tenant ID - :param app_model_config_dict: app model config dict - :param skip_check: skip check - :raises ProviderTokenNotInitError: provider token not init error - :return: app orchestration config entity - """ - properties = {} - - copy_app_model_config_dict = app_model_config_dict.copy() - - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=copy_app_model_config_dict['model']['provider'], - model_type=ModelType.LLM - ) - - provider_name = provider_model_bundle.configuration.provider.provider - model_name = copy_app_model_config_dict['model']['name'] - - model_type_instance = provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # check model credentials - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=copy_app_model_config_dict['model']['name'] - ) - - if model_credentials is None: - if not skip_check: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - else: - model_credentials = {} - - if not skip_check: - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = copy_app_model_config_dict['model'].get('completion_params') - stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] - - # get model mode - model_mode = copy_app_model_config_dict['model'].get('mode') - if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=copy_app_model_config_dict['model']['name'], - credentials=model_credentials - ) - - model_mode = mode_enum.value - - model_schema = model_type_instance.get_model_schema( - copy_app_model_config_dict['model']['name'], - model_credentials - ) - - if not skip_check and not model_schema: - raise ValueError(f"Model {model_name} not exist.") - - properties['model_config'] = ModelConfigEntity( - provider=copy_app_model_config_dict['model']['provider'], - model=copy_app_model_config_dict['model']['name'], - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - # prompt template - prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type']) - if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: - simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "") - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) - else: - advanced_chat_prompt_template = None - chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {}) - if 
chat_prompt_config: - chat_prompt_messages = [] - for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) - - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) - - advanced_completion_prompt_template = None - completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {}) - if completion_prompt_config: - completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], - } - - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] - } - - advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( - **completion_prompt_template_params - ) - - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template - ) - - # external data variables - properties['external_data_variables'] = [] - - # old external_data_tools - external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) - for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: - continue - - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] - ) - ) - - properties['variables'] = [] - - # variables and external_data_tools - for variable in copy_app_model_config_dict.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] - ) - ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, - ]: - properties['variables'].append( - VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - properties['variables'].append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), - ) - ) - - # show retrieve source - show_retrieve_source = False - retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource') - if retriever_resource_dict: - if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: - show_retrieve_source = True - - properties['show_retrieve_source'] = show_retrieve_source - - dataset_ids = [] - if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): - datasets = 
copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) - - for dataset in datasets.get('datasets', []): - keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': - continue - dataset = dataset['dataset'] - - if 'enabled' not in dataset or not dataset['enabled']: - continue - - dataset_id = dataset.get('id', None) - if dataset_id: - dataset_ids.append(dataset_id) - - if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ - and 'enabled' in copy_app_model_config_dict['agent_mode'] \ - and copy_app_model_config_dict['agent_mode']['enabled']: - - agent_dict = copy_app_model_config_dict.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if copy_app_model_config_dict['model']['provider'] == 'openai': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - - agent_tools = [] - for tool in agent_dict.get('tools', []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: - continue - - agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} - } - - agent_tools.append(AgentToolEntity(**agent_tool_properties)) - elif len(keys) == 1: - # old standard - key = list(tool.keys())[0] - - if key != 'dataset': - continue - - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - continue - - dataset_id = tool_item['id'] - dataset_ids.append(dataset_id) - - if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ - copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} - # check model mode - model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), - ) - else: - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), - ) - - properties['agent'] = AgentEntity( - provider=properties['model_config'].provider, - model=properties['model_config'].model, - strategy=strategy, - prompt=agent_prompt_entity, - tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) - ) - - if len(dataset_ids) > 0: - # dataset configs - dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) - query_variable = copy_app_model_config_dict.get('dataset_query_variable') - - if dataset_configs['retrieval_model'] == 'single': - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - 
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) - ) - else: - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - top_k=dataset_configs.get('top_k'), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') - ) - ) - - # file upload - file_upload_dict = copy_app_model_config_dict.get('file_upload') - if file_upload_dict: - if 'image' in file_upload_dict and file_upload_dict['image']: - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - properties['file_upload'] = FileUploadEntity( - image_config={ - 'number_limits': file_upload_dict['image']['number_limits'], - 'detail': file_upload_dict['image']['detail'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] - } - ) - - # opening statement - properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement') - - # suggested questions after answer - suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer') - if suggested_questions_after_answer_dict: - if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: - properties['suggested_questions_after_answer'] = True - - # more like this - more_like_this_dict = copy_app_model_config_dict.get('more_like_this') - if more_like_this_dict: - if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']: - properties['more_like_this'] = True - - # speech to text - speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text') - if speech_to_text_dict: - if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']: - properties['speech_to_text'] = True - - # text to speech - text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech') - if text_to_speech_dict: - if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: - properties['text_to_speech'] = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), - ) - - # sensitive word avoidance - sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance') - if sensitive_word_avoidance_dict: - if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: - properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), - ) - - return AppOrchestrationConfigEntity(**properties) diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py index c09cae3245ce1b..4bd491269cebb7 100644 --- a/api/core/app/app_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -6,8 +6,8 @@ from sqlalchemy.orm import DeclarativeMeta -from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent, diff --git a/api/core/app/apps/__init__.py b/api/core/app/apps/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff 
--git a/api/core/app/apps/advanced_chat/__init__.py b/api/core/app/apps/advanced_chat/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py new file mode 100644 index 00000000000000..ab7857c4adad7d --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -0,0 +1,94 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.model import AppMode, App +from models.workflow import Workflow + + +class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): + """ + Advanced Chatbot App Config Entity. + """ + pass + + +class AdvancedChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: + features_dict = workflow.features_dict + + app_config = AdvancedChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=features_dict + ), + variables=WorkflowVariablesConfigManager.convert( + workflow=workflow + ), + additional_features=cls.convert_features(features_dict) + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: if True, only structure validation will be performed + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource 
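+        # (each validate_and_set_defaults call in this chain normalizes one
+        # feature's sub-config and returns the top-level keys it owns; the
+        # union of those keys drives the filtering at the end of the method)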
+ config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + diff --git a/api/core/app/apps/agent_chat/__init__.py b/api/core/app/apps/agent_chat/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/apps/agent_chat/app_config_manager.py similarity index 51% rename from api/core/app/agent_chat/config_validator.py rename to api/core/app/apps/agent_chat/app_config_manager.py index 82bc40bd9bfd7f..96dac4bd010f8e 100644 --- a/api/core/app/agent_chat/config_validator.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,24 +1,82 @@ import uuid - -from core.app.validators.dataset_retrieval import DatasetValidator -from core.app.validators.external_data_fetch import ExternalDataFetchValidator -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.model_validator import ModelValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.opening_statement import OpeningStatementValidator -from core.app.validators.prompt import PromptValidator -from core.app.validators.retriever_resource import RetrieverResourceValidator -from core.app.validators.speech_to_text import SpeechToTextValidator -from core.app.validators.suggested_questions import SuggestedQuestionsValidator -from core.app.validators.text_to_speech import TextToSpeechValidator -from core.app.validators.user_input_form import UserInputFormValidator +from typing import Optional + +from core.agent.entities import AgentEntity +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, DatasetEntity +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import 
PlanningStrategy -from models.model import AppMode +from models.model import AppMode, App, AppModelConfig OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] -class AgentChatAppConfigValidator: +class AgentChatAppConfig(EasyUIBasedAppConfig): + """ + Agent Chatbot App Config Entity. + """ + agent: Optional[AgentEntity] = None + + +class AgentChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> AgentChatAppConfig: + """ + Convert app model config to agent chat app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = AgentChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + agent=AgentConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + @classmethod def config_validate(cls, tenant_id: str, config: dict) -> dict: """ @@ -32,23 +90,19 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys = [] # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) related_config_keys.extend(current_related_config_keys) # agent_mode @@ -56,27 +110,29 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) 
+ config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) related_config_keys.extend(current_related_config_keys) # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -143,7 +199,7 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t except ValueError: raise ValueError("id in dataset must be of UUID type") - if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]): raise ValueError("Dataset ID does not exist, please check your permission.") else: # latest style, use key-value pair diff --git a/api/core/app/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py similarity index 83% rename from api/core/app/agent_chat/app_runner.py rename to api/core/app/apps/agent_chat/app_runner.py index 38789348ad749d..2f1de8f108afba 100644 --- a/api/core/app/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -2,10 +2,12 @@ from typing import cast from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.base_app_runner import AppRunner -from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -24,7 +26,7 @@ class AgentChatAppRunner(AppRunner): """ Agent Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: 
AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -36,12 +38,13 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(AgentChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -53,8 +56,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -64,22 +67,22 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( conversation=conversation, model_instance=model_instance ) - + # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) # memory(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -91,15 +94,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -123,7 +126,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -131,7 +134,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, return # fill in variable inputs from external data tools if exists - external_data_tools = 
app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -146,8 +149,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional), external data, dataset context(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -164,25 +167,25 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if hosting_moderation_result: return - agent_entity = app_orchestration_config.agent + agent_entity = app_config.agent # load tool variables tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, user_id=application_generate_entity.user_id, - tenant_id=application_generate_entity.tenant_id) + tenant_id=app_config.tenant_id) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) # init model instance model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -203,10 +206,10 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: assistant_cot_runner = CotAgentRunner( - tenant_id=application_generate_entity.tenant_id, + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -225,10 +228,10 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: assistant_fc_runner = FunctionCallAgentRunner( - tenant_id=application_generate_entity.tenant_id, + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -289,7 +292,7 @@ def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversation 'pool': db_variables.variables }) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, + def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity, message: Message) -> LLMUsage: """ Get usage 
of all agent thoughts diff --git a/api/core/app/base_app_runner.py b/api/core/app/apps/base_app_runner.py similarity index 93% rename from api/core/app/base_app_runner.py rename to api/core/app/apps/base_app_runner.py index 2760d04180a00c..93f819af0883ed 100644 --- a/api/core/app/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -2,16 +2,13 @@ from collections.abc import Generator from typing import Optional, Union, cast +from core.app.app_config.entities import PromptTemplateEntity, ExternalDataVariableEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature -from core.entities.application_entities import ( - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - ExternalDataVariableEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, + InvokeFrom, EasyUIBasedModelConfigEntity, ) from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileObj @@ -29,7 +26,7 @@ class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], query: Optional[str] = None) -> int: @@ -85,7 +82,7 @@ def get_pre_calculate_rest_tokens(self, app_record: App, return rest_tokens - def recale_llm_max_tokens(self, model_config: ModelConfigEntity, + def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance @@ -121,7 +118,7 @@ def recale_llm_max_tokens(self, model_config: ModelConfigEntity, model_config.parameters[parameter_rule.name] = max_tokens def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -170,7 +167,7 @@ def organize_prompt_messages(self, app_record: App, return prompt_messages, stop def direct_output(self, queue_manager: AppQueueManager, - app_orchestration_config: AppOrchestrationConfigEntity, + app_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list, text: str, stream: bool, @@ -178,7 +175,7 @@ def direct_output(self, queue_manager: AppQueueManager, """ Direct output :param queue_manager: application queue manager - :param app_orchestration_config: app orchestration config + :param app_generate_entity: app generate entity :param prompt_messages: prompt messages :param text: text :param stream: stream @@ -189,7 +186,7 @@ def direct_output(self, queue_manager: AppQueueManager, index = 0 for token in text: queue_manager.publish_chunk_message(LLMResultChunk( - model=app_orchestration_config.model_config.model, + model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, @@ -201,7 +198,7 @@ def direct_output(self, queue_manager: AppQueueManager, queue_manager.publish_message_end( llm_result=LLMResult( - model=app_orchestration_config.model_config.model, + model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, 
message=AssistantPromptMessage(content=text), usage=usage if usage else LLMUsage.empty_usage() @@ -294,14 +291,14 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, def moderation_for_inputs(self, app_id: str, tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, + app_generate_entity: EasyUIBasedAppGenerateEntity, inputs: dict, query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. :param app_id: app id :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity + :param app_generate_entity: app generate entity :param inputs: inputs :param query: query :return: @@ -310,12 +307,12 @@ def moderation_for_inputs(self, app_id: str, return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, - app_orchestration_config_entity=app_orchestration_config_entity, + app_config=app_generate_entity.app_config, inputs=inputs, query=query, ) - def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ @@ -334,7 +331,7 @@ def check_hosting_moderation(self, application_generate_entity: ApplicationGener if moderation_result: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=application_generate_entity.app_orchestration_config_entity, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text="I apologize for any confusion, " \ "but I'm an AI assistant to be helpful, harmless, and honest.", diff --git a/api/core/app/apps/chat/__init__.py b/api/core/app/apps/chat/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py new file mode 100644 index 00000000000000..62b2aaae5a5b2b --- /dev/null +++ b/api/core/app/apps/chat/app_config_manager.py @@ -0,0 +1,135 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppMode, App, AppModelConfig + + +class ChatAppConfig(EasyUIBasedAppConfig): + """ + Chatbot App Config Entity. 
+ """ + pass + + +class ChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> ChatAppConfig: + """ + Convert app model config to chat app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = ChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, 
current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py similarity index 76% rename from api/core/app/chat/app_runner.py rename to api/core/app/apps/chat/app_runner.py index 4c8018572e69b2..403a2d4476e304 100644 --- a/api/core/app/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,10 +1,12 @@ import logging +from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.base_app_runner import AppRunner +from core.app.apps.chat.app_config_manager import ChatAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - ApplicationGenerateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -21,7 +23,7 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -33,12 +35,13 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(ChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -50,8 +53,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -61,8 +64,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( @@ -75,8 +78,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - 
model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -88,15 +91,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -120,7 +123,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -128,7 +131,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -140,7 +143,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # get context from datasets context = None - if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, app_record.id, @@ -152,11 +155,11 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - model_config=app_orchestration_config.model_config, - config=app_orchestration_config.dataset, + model_config=application_generate_entity.model_config, + config=app_config.dataset, query=query, invoke_from=application_generate_entity.invoke_from, - show_retrieve_source=app_orchestration_config.show_retrieve_source, + show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, memory=memory ) @@ -166,8 +169,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional), external data, dataset context(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -186,22 +189,22 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( - model_config=app_orchestration_config.model_config, + self.recale_llm_max_tokens( + 
model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) # Invoke model model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) db.session.close() invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=application_generate_entity.model_config.parameters, stop=stop, stream=application_generate_entity.stream, user=application_generate_entity.user_id, diff --git a/api/core/app/apps/completion/__init__.py b/api/core/app/apps/completion/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py new file mode 100644 index 00000000000000..b920f369b5e93f --- /dev/null +++ b/api/core/app/apps/completion/app_config_manager.py @@ -0,0 +1,118 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppMode, App, AppModelConfig + + +class CompletionAppConfig(EasyUIBasedAppConfig): + """ + Completion App Config Entity. 
+ """ + pass + + +class CompletionAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> CompletionAppConfig: + """ + Convert app model config to completion app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = CompletionAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py similarity index 74% rename from api/core/app/completion/app_runner.py rename to 
api/core/app/apps/completion/app_runner.py index ab2f40ad9a877a..8f0f191d454f84 100644 --- a/api/core/app/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,10 +1,12 @@ import logging +from typing import cast from core.app.app_queue_manager import AppQueueManager -from core.app.base_app_runner import AppRunner +from core.app.apps.completion.app_config_manager import CompletionAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - ApplicationGenerateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, ) from core.model_manager import ModelInstance from core.moderation.base import ModerationException @@ -20,7 +22,7 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, message: Message) -> None: """ @@ -30,12 +32,13 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(CompletionAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -47,8 +50,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -58,8 +61,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Include: prompt template, inputs, query(optional), files(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -70,15 +73,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -86,7 +89,7 @@ def run(self, application_generate_entity: 
ApplicationGenerateEntity, return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -98,7 +101,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # get context from datasets context = None - if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, app_record.id, @@ -107,18 +110,18 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, application_generate_entity.invoke_from ) - dataset_config = app_orchestration_config.dataset + dataset_config = app_config.dataset if dataset_config and dataset_config.retrieve_config.query_variable: query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - model_config=app_orchestration_config.model_config, + model_config=application_generate_entity.model_config, config=dataset_config, query=query, invoke_from=application_generate_entity.invoke_from, - show_retrieve_source=app_orchestration_config.show_retrieve_source, + show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback ) @@ -127,8 +130,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional), external data, dataset context(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -147,19 +150,19 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit self.recale_llm_max_tokens( - model_config=app_orchestration_config.model_config, + model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) # Invoke model model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=application_generate_entity.model_config.parameters, stop=stop, stream=application_generate_entity.stream, user=application_generate_entity.user_id, diff --git a/api/core/app/apps/workflow/__init__.py b/api/core/app/apps/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py new file mode 100644 index 00000000000000..35da72b63e0043 --- /dev/null +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -0,0 +1,71 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from 
core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.model import AppMode, App +from models.workflow import Workflow + + +class WorkflowAppConfig(WorkflowUIBasedAppConfig): + """ + Workflow App Config Entity. + """ + pass + + +class WorkflowAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: + features_dict = workflow.features_dict + + app_config = WorkflowAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=features_dict + ), + variables=WorkflowVariablesConfigManager.convert( + workflow=workflow + ), + additional_features=cls.convert_features(features_dict) + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for workflow app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: only validate the structure of the config + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/chat/config_validator.py b/api/core/app/chat/config_validator.py deleted file mode 100644 index adb8408e285013..00000000000000 --- a/api/core/app/chat/config_validator.py +++ /dev/null @@ -1,82 +0,0 @@ -from core.app.validators.dataset_retrieval import DatasetValidator -from core.app.validators.external_data_fetch import ExternalDataFetchValidator -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.model_validator import ModelValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.opening_statement import OpeningStatementValidator -from core.app.validators.prompt import PromptValidator -from core.app.validators.retriever_resource import RetrieverResourceValidator -from core.app.validators.speech_to_text import SpeechToTextValidator -from core.app.validators.suggested_questions import SuggestedQuestionsValidator -from core.app.validators.text_to_speech import TextToSpeechValidator -from core.app.validators.user_input_form import UserInputFormValidator -from models.model import AppMode - 
- -class ChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: - """ - Validate for chat app model config - - :param tenant_id: tenant id - :param config: app model config args - """ - app_mode = AppMode.CHAT - - related_config_keys = [] - - # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # dataset_query_variable - config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/app/completion/config_validator.py b/api/core/app/completion/config_validator.py deleted file mode 100644 index 7cc35efd64ac62..00000000000000 --- a/api/core/app/completion/config_validator.py +++ /dev/null @@ -1,67 +0,0 @@ -from core.app.validators.dataset_retrieval import DatasetValidator -from core.app.validators.external_data_fetch import ExternalDataFetchValidator -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.model_validator import ModelValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.more_like_this import MoreLikeThisValidator -from core.app.validators.prompt import PromptValidator -from core.app.validators.text_to_speech import TextToSpeechValidator -from core.app.validators.user_input_form import UserInputFormValidator 
-from models.model import AppMode - - -class CompletionAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: - """ - Validate for completion app model config - - :param tenant_id: tenant id - :param config: app model config args - """ - app_mode = AppMode.COMPLETION - - related_config_keys = [] - - # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # dataset_query_variable - config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # more_like_this - config, current_related_config_keys = MoreLikeThisValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py new file mode 100644 index 00000000000000..fae9044fc34f1e --- /dev/null +++ b/api/core/app/entities/app_invoke_entities.py @@ -0,0 +1,111 @@ +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + +from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.entities.provider_configuration import ProviderModelBundle +from core.file.file_obj import FileObj +from core.model_runtime.entities.model_entities import AIModelEntity + + +class InvokeFrom(Enum): + """ + Invoke From. + """ + SERVICE_API = 'service-api' + WEB_APP = 'web-app' + EXPLORE = 'explore' + DEBUGGER = 'debugger' + + @classmethod + def value_of(cls, value: str) -> 'InvokeFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid invoke from value {value}') + + def to_source(self) -> str: + """ + Get source of invoke from. 
+ + :return: source + """ + if self == InvokeFrom.WEB_APP: + return 'web_app' + elif self == InvokeFrom.DEBUGGER: + return 'dev' + elif self == InvokeFrom.EXPLORE: + return 'explore_app' + elif self == InvokeFrom.SERVICE_API: + return 'api' + + return 'dev' + + +class EasyUIBasedModelConfigEntity(BaseModel): + """ + Model Config Entity. + """ + provider: str + model: str + model_schema: AIModelEntity + mode: str + provider_model_bundle: ProviderModelBundle + credentials: dict[str, Any] = {} + parameters: dict[str, Any] = {} + stop: list[str] = [] + + +class EasyUIBasedAppGenerateEntity(BaseModel): + """ + EasyUI Based Application Generate Entity. + """ + task_id: str + + # app config + app_config: EasyUIBasedAppConfig + model_config: EasyUIBasedModelConfigEntity + + conversation_id: Optional[str] = None + inputs: dict[str, str] + query: Optional[str] = None + files: list[FileObj] = [] + user_id: str + # extras + stream: bool + invoke_from: InvokeFrom + + # extra parameters, like: auto_generate_conversation_name + extras: dict[str, Any] = {} + + +class WorkflowUIBasedAppGenerateEntity(BaseModel): + """ + Workflow UI Based Application Generate Entity. + """ + task_id: str + + # app config + app_config: WorkflowUIBasedAppConfig + + inputs: dict[str, str] + files: list[FileObj] = [] + user_id: str + # extras + stream: bool + invoke_from: InvokeFrom + + # extra parameters + extras: dict[str, Any] = {} + + +class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity): + conversation_id: Optional[str] = None + query: str diff --git a/api/core/entities/queue_entities.py b/api/core/app/entities/queue_entities.py similarity index 100% rename from api/core/entities/queue_entities.py rename to api/core/app/entities/queue_entities.py diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index fd516e465ff38c..19ff94de5e8d58 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d8ae7adcac5d47..ec316248a27afe 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,6 +1,6 @@ import logging -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation from core.model_runtime.entities.message_entities import PromptMessage @@ -8,7 +8,7 @@ class HostingModerationFeature: - def check(self, application_generate_entity: ApplicationGenerateEntity, + def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -16,8 +16,7 @@ def check(self, application_generate_entity: ApplicationGenerateEntity, :param prompt_messages: prompt messages :return: """ - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - model_config = app_orchestration_config.model_config 
+ model_config = application_generate_entity.model_config text = "" for prompt_message in prompt_messages: diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py index dc6ea2db7924c5..359369ef59a9b5 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -7,8 +7,8 @@ from pydantic import BaseModel from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom -from core.entities.queue_entities import ( +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, @@ -58,7 +58,7 @@ class GenerateTaskPipeline: GenerateTaskPipeline is a class that generate stream output and state management for Application. """ - def __init__(self, application_generate_entity: ApplicationGenerateEntity, + def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -75,7 +75,7 @@ def __init__(self, application_generate_entity: ApplicationGenerateEntity, self._message = message self._task_state = TaskState( llm_result=LLMResult( - model=self._application_generate_entity.app_orchestration_config_entity.model_config.model, + model=self._application_generate_entity.model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() @@ -127,7 +127,7 @@ def _process_blocking_response(self) -> dict: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.app_orchestration_config_entity.model_config + model_config = self._application_generate_entity.model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -210,7 +210,7 @@ def _process_stream_response(self) -> Generator: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.app_orchestration_config_entity.model_config + model_config = self._application_generate_entity.model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -569,7 +569,7 @@ def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMess :return: """ prompts = [] - if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat': + if self._application_generate_entity.model_config.mode == 'chat': for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' @@ -638,13 +638,13 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: Init output moderation. 
:return: """ - app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity - sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance if sensitive_word_avoidance: return OutputModeration( - tenant_id=self._application_generate_entity.tenant_id, - app_id=self._application_generate_entity.app_id, + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, rule=ModerationRule( type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config diff --git a/api/core/app/validators/external_data_fetch.py b/api/core/app/validators/external_data_fetch.py deleted file mode 100644 index 5910aa17e76fd8..00000000000000 --- a/api/core/app/validators/external_data_fetch.py +++ /dev/null @@ -1,39 +0,0 @@ - -from core.external_data_tool.factory import ExternalDataToolFactory - - -class ExternalDataFetchValidator: - @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: - """ - Validate and set defaults for external data fetch feature - - :param tenant_id: workspace id - :param config: app model config args - """ - if not config.get("external_data_tools"): - config["external_data_tools"] = [] - - if not isinstance(config["external_data_tools"], list): - raise ValueError("external_data_tools must be of list type") - - for tool in config["external_data_tools"]: - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - - if not tool["enabled"]: - continue - - if "type" not in tool or not tool["type"]: - raise ValueError("external_data_tools[].type is required") - - typ = tool["type"] - config = tool["config"] - - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) - - return config, ["external_data_tools"] diff --git a/api/core/app/validators/user_input_form.py b/api/core/app/validators/user_input_form.py deleted file mode 100644 index 249d6745ae7bf3..00000000000000 --- a/api/core/app/validators/user_input_form.py +++ /dev/null @@ -1,61 +0,0 @@ -import re - - -class UserInputFormValidator: - @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: - """ - Validate and set defaults for user input form - - :param config: app model config args - """ - if not config.get("user_input_form"): - config["user_input_form"] = [] - - if not isinstance(config["user_input_form"], list): - raise ValueError("user_input_form must be a list of objects") - - variables = [] - for item in config["user_input_form"]: - key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: - raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") - - form_item = item[key] - if 'label' not in form_item: - raise ValueError("label is required in user_input_form") - - if not isinstance(form_item["label"], str): - raise ValueError("label in user_input_form must be of string type") - - if 'variable' not in form_item: - raise ValueError("variable is required in user_input_form") - - if not isinstance(form_item["variable"], str): - raise ValueError("variable in user_input_form must be of string type") - - pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") - if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form 
must be a string, " - "and cannot start with a number") - - variables.append(form_item["variable"]) - - if 'required' not in form_item or not form_item["required"]: - form_item["required"] = False - - if not isinstance(form_item["required"], bool): - raise ValueError("required in user_input_form must be of boolean type") - - if key == "select": - if 'options' not in form_item or not form_item["options"]: - form_item["options"] = [] - - if not isinstance(form_item["options"], list): - raise ValueError("options in user_input_form must be a list of strings") - - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: - raise ValueError("default value in user_input_form must be in the options list") - - return config, ["user_input_form"] diff --git a/api/core/app/workflow/config_validator.py b/api/core/app/workflow/config_validator.py deleted file mode 100644 index e8381146a75501..00000000000000 --- a/api/core/app/workflow/config_validator.py +++ /dev/null @@ -1,39 +0,0 @@ -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.text_to_speech import TextToSpeechValidator - - -class WorkflowAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: - """ - Validate for workflow app model config - - :param tenant_id: tenant id - :param config: app model config args - :param only_structure_validate: only validate the structure of the config - """ - related_config_keys = [] - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate - ) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py deleted file mode 100644 index 8a340a8b815111..00000000000000 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ /dev/null @@ -1,262 +0,0 @@ -import json -import logging -import time -from typing import Any, Optional, Union, cast - -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - -from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.callback_handler.entity.agent_loop import AgentLoop -from core.entities.application_entities import ModelConfigEntity -from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from 
extensions.ext_database import db -from models.model import Message, MessageAgentThought, MessageChain - - -class AgentLoopGatherCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - raise_error: bool = True - - def __init__(self, model_config: ModelConfigEntity, - queue_manager: AppQueueManager, - message: Message, - message_chain: MessageChain) -> None: - """Initialize callback handler.""" - self.model_config = model_config - self.queue_manager = queue_manager - self.message = message - self.message_chain = message_chain - model_type_instance = self.model_config.provider_model_bundle.model_type_instance - self.model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def agent_loops(self) -> list[AgentLoop]: - return self._agent_loops - - def clear_agent_loops(self) -> None: - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return True - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return True - - def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None: - if not self._current_loop: - # Agent start with a LLM query - self._current_loop = AgentLoop( - position=len(self._agent_loops) + 1, - prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]), - status='llm_started', - started_at=time.perf_counter() - ) - - def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None: - if self._current_loop and self._current_loop.status == 'llm_started': - self._current_loop.status = 'llm_end' - if result.usage: - self._current_loop.prompt_tokens = result.usage.prompt_tokens - else: - self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)] - ) - - completion_message = result.message - if completion_message.tool_calls: - self._current_loop.completion \ - = json.dumps({'function_call': completion_message.tool_calls}) - else: - self._current_loop.completion = completion_message.content - - if result.usage: - self._current_loop.completion_tokens = result.usage.completion_tokens - else: - self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)] - ) - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - pass - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.debug("Agent on_llm_error: %s", error) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - # kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '} - # 
input_str='action-input' - # serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'} - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = json.dumps({"query": action.tool_input} - if isinstance(action.tool_input, str) else action.tool_input) - completion = None - if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \ - or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction): - thought = action.log.strip() - completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']}) - else: - action_name_position = action.log.index("Action:") if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - - if self._current_loop and self._current_loop.status == 'llm_end': - self._current_loop.status = 'agent_action' - self._current_loop.thought = thought - self._current_loop.tool_name = tool - self._current_loop.tool_input = tool_input - if completion is not None: - self._current_loop.completion = completion - - self._message_agent_thought = self._init_agent_thought() - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - # kwargs={'name': 'Search'} - # llm_prefix='Thought:' - # observation_prefix='Observation: ' - # output='53 years' - - if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None': - self._current_loop.status = 'tool_end' - self._current_loop.tool_output = output - self._current_loop.completed = True - self._current_loop.completed_at = time.perf_counter() - self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at - - self._complete_agent_thought(self._message_agent_thought) - - self._agent_loops.append(self._current_loop) - self._current_loop = None - self._message_agent_thought = None - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - logging.debug("Agent on_tool_error: %s", error) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - # Final Answer - if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'): - self._current_loop.status = 'agent_finish' - self._current_loop.completed = True - self._current_loop.completed_at = time.perf_counter() - self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at - self._current_loop.thought = '[DONE]' - self._message_agent_thought = self._init_agent_thought() - - self._complete_agent_thought(self._message_agent_thought) - - self._agent_loops.append(self._current_loop) - self._current_loop = None - self._message_agent_thought = None - elif not self._current_loop and self._agent_loops: - self._agent_loops[-1].status = 'agent_finish' - - def _init_agent_thought(self) -> MessageAgentThought: - message_agent_thought = MessageAgentThought( - message_id=self.message.id, - message_chain_id=self.message_chain.id, - 
position=self._current_loop.position, - thought=self._current_loop.thought, - tool=self._current_loop.tool_name, - tool_input=self._current_loop.tool_input, - message=self._current_loop.prompt, - message_price_unit=0, - answer=self._current_loop.completion, - answer_price_unit=0, - created_by_role=('account' if self.message.from_source == 'console' else 'end_user'), - created_by=(self.message.from_account_id - if self.message.from_source == 'console' else self.message.from_end_user_id) - ) - - db.session.add(message_agent_thought) - db.session.commit() - - self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER) - - return message_agent_thought - - def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: - loop_message_tokens = self._current_loop.prompt_tokens - loop_answer_tokens = self._current_loop.completion_tokens - - # transform usage - llm_usage = self.model_type_instance._calc_response_usage( - self.model_config.model, - self.model_config.credentials, - loop_message_tokens, - loop_answer_tokens - ) - - message_agent_thought.observation = self._current_loop.tool_output - message_agent_thought.tool_process_data = '' # currently not support - message_agent_thought.message_token = loop_message_tokens - message_agent_thought.message_unit_price = llm_usage.prompt_unit_price - message_agent_thought.message_price_unit = llm_usage.prompt_price_unit - message_agent_thought.answer_token = loop_answer_tokens - message_agent_thought.answer_unit_price = llm_usage.completion_unit_price - message_agent_thought.answer_price_unit = llm_usage.completion_price_unit - message_agent_thought.latency = self._current_loop.latency - message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens - message_agent_thought.total_price = llm_usage.total_price - message_agent_thought.currency = llm_usage.currency - db.session.commit() diff --git a/api/core/callback_handler/entity/agent_loop.py b/api/core/callback_handler/entity/agent_loop.py deleted file mode 100644 index 56634bb19e4990..00000000000000 --- a/api/core/callback_handler/entity/agent_loop.py +++ /dev/null @@ -1,23 +0,0 @@ -from pydantic import BaseModel - - -class AgentLoop(BaseModel): - position: int = 1 - - thought: str = None - tool_name: str = None - tool_input: str = None - tool_output: str = None - - prompt: str = None - prompt_tokens: int = 0 - completion: str = None - completion_tokens: int = 0 - - latency: float = None - - status: str = 'llm_started' - completed: bool = False - - started_at: float = None - completed_at: float = None \ No newline at end of file diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index e49a09d4c4b7ba..ca781a55bc9c90 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,6 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 64c7d1e859532f..8601cb34e79582 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ 
b/api/core/external_data_tool/external_data_fetch.py @@ -5,7 +5,7 @@ from flask import Flask, current_app -from core.entities.application_entities import ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory logger = logging.getLogger(__name__) diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 435074f7430f4c..bd896719c21835 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +from core.app.app_config.entities import FileUploadEntity from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db @@ -50,7 +51,7 @@ class FileObj(BaseModel): transfer_method: FileTransferMethod url: Optional[str] upload_file_id: Optional[str] - file_config: dict + file_upload_entity: FileUploadEntity @property def data(self) -> Optional[str]: @@ -63,7 +64,7 @@ def preview_url(self) -> Optional[str]: @property def prompt_message_content(self) -> ImagePromptMessageContent: if self.type == FileType.IMAGE: - image_config = self.file_config.get('image') + image_config = self.file_upload_entity.image_config return ImagePromptMessageContent( data=self.data, diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index c13207357820ee..9d122c41204308 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Union import requests +from core.app.app_config.entities import FileUploadEntity from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType from extensions.ext_database import db from models.account import Account -from models.model import AppModelConfig, EndUser, MessageFile, UploadFile +from models.model import EndUser, MessageFile, UploadFile from services.file_service import IMAGE_EXTENSIONS @@ -15,18 +16,16 @@ def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, + def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity, user: Union[Account, EndUser]) -> list[FileObj]: """ validate and transform files arg :param files: - :param app_model_config: + :param file_upload_entity: :param user: :return: """ - file_upload_config = app_model_config.file_upload_dict - for file in files: if not isinstance(file, dict): raise ValueError('Invalid file format, must be dict') @@ -45,17 +44,17 @@ def validate_and_transform_files_arg(self, files: list[dict], app_model_config: raise ValueError('Missing file upload_file_id') # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_config) + type_file_objs = self._to_file_objs(files, file_upload_entity) # validate files new_files = [] for file_type, file_objs in type_file_objs.items(): if file_type == FileType.IMAGE: # parse and validate files - image_config = file_upload_config.get('image') + image_config = file_upload_entity.image_config # check if image file feature is enabled - if not image_config['enabled']: + if not image_config: continue # Validate number of files @@ -96,27 +95,27 @@ def validate_and_transform_files_arg(self, files: list[dict], app_model_config: # return all file objs 
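         # files of any type whose upload feature is disabled were skipped above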
return new_files - def transform_message_files(self, files: list[MessageFile], file_upload_config: Optional[dict]) -> list[FileObj]: + def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]: """ transform message files :param files: - :param file_upload_config: + :param file_upload_entity: :return: """ # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_config) + type_file_objs = self._to_file_objs(files, file_upload_entity) # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_upload_config: dict) -> dict[FileType, list[FileObj]]: + file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]: """ transform files to file objs :param files: - :param file_upload_config: + :param file_upload_entity: :return: """ type_file_objs: dict[FileType, list[FileObj]] = { @@ -133,7 +132,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], if file.belongs_to == FileBelongsTo.ASSISTANT.value: continue - file_obj = self._to_file_obj(file, file_upload_config) + file_obj = self._to_file_obj(file, file_upload_entity) if file_obj.type not in type_file_objs: continue @@ -141,7 +140,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj: + def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj: """ transform file to file obj @@ -156,7 +155,7 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) transfer_method=transfer_method, url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - file_config=file_upload_config + file_upload_entity=file_upload_entity ) else: return FileObj( @@ -166,7 +165,7 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, upload_file_id=file.upload_file_id or None, - file_config=file_upload_config + file_upload_entity=file_upload_entity ) def _check_image_remote_url(self, url): diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 86d6b498da35e7..bff9b9cf1fab4b 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,7 +1,7 @@ import logging import random -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from extensions.ext_hosting_provider import hosting_configuration @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -def check_moderation(model_config: ModelConfigEntity, text: str) -> bool: +def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config if (moderation_config and moderation_config.enabled is True and 'openai' in hosting_configuration.provider_map diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 00813faef7ed84..4fe150e9833df9 100644 --- 
a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,3 +1,5 @@ +from core.app.app_config.entities import FileUploadEntity +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ( @@ -43,12 +45,18 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, for message in messages: files = message.message_files if files: - file_objs = message_file_parser.transform_message_files( - files, - message.app_model_config.file_upload_dict - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - else message.workflow_run.workflow.features_dict.get('file_upload', {}) - ) + if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + else: + file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict) + + if file_upload_entity: + file_objs = message_file_parser.transform_message_files( + files, + file_upload_entity + ) + else: + file_objs = [] if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 2129c58d8d2923..8fbc0c2d5003f6 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,6 +1,6 @@ import logging -from core.entities.application_entities import AppOrchestrationConfigEntity +from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationException from core.moderation.factory import ModerationFactory @@ -10,22 +10,22 @@ class InputModeration: def check(self, app_id: str, tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, + app_config: AppConfig, inputs: dict, query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
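         If sensitive_word_avoidance is not configured for the app, the inputs
         and query are returned unchanged.
 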
:param app_id: app id :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity + :param app_config: app config :param inputs: inputs :param query: query :return: """ - if not app_orchestration_config_entity.sensitive_word_avoidance: + if not app_config.sensitive_word_avoidance: return False, inputs, query - sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance + sensitive_word_avoidance_config = app_config.sensitive_word_avoidance moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 6d0a1d31f5ad8e..129c2a4cd273ea 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,10 +1,7 @@ from typing import Optional -from core.entities.application_entities import ( - AdvancedCompletionPromptTemplateEntity, - ModelConfigEntity, - PromptTemplateEntity, -) +from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( @@ -31,7 +28,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: + model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: prompt_messages = [] model_mode = ModelMode.value_of(model_config.mode) @@ -65,7 +62,7 @@ def _get_completion_model_prompt_messages(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: + model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: """ Get completion model prompt messages. """ @@ -113,7 +110,7 @@ def _get_chat_model_prompt_messages(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: + model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: """ Get chat model prompt messages. 
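         Assembles the message list from the configured chat prompt template,
         together with any context, files, memory histories and the current query.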
""" @@ -202,7 +199,7 @@ def _set_histories_variable(self, memory: TokenBufferMemory, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, prompt_template: PromptTemplateParser, prompt_inputs: dict, - model_config: ModelConfigEntity) -> dict: + model_config: EasyUIBasedModelConfigEntity) -> dict: if '#histories#' in prompt_template.variable_keys: if memory: inputs = {'#histories#': '', **prompt_inputs} diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 9c554140b7b16b..7fe8128a492d63 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,6 +1,6 @@ from typing import Optional, cast -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey @@ -10,14 +10,14 @@ class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], - model_config: ModelConfigEntity) -> list[PromptMessage]: + model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int: + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index af7b695bb33718..faf1f888e2b649 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -3,10 +3,8 @@ import os from typing import Optional -from core.entities.application_entities import ( - ModelConfigEntity, - PromptTemplateEntity, -) +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( @@ -54,7 +52,7 @@ def get_prompt(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> \ + model_config: EasyUIBasedModelConfigEntity) -> \ tuple[list[PromptMessage], Optional[list[str]]]: model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.CHAT: @@ -83,7 +81,7 @@ def get_prompt(self, return prompt_messages, stops def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, pre_prompt: str, inputs: dict, query: Optional[str] = None, @@ -164,7 +162,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, context: Optional[str], files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + model_config: EasyUIBasedModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] @@ -202,7 +200,7 @@ def 
_get_completion_model_prompt_messages(self, app_mode: AppMode, context: Optional[str], files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + model_config: EasyUIBasedModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( diff --git a/api/core/rag/retrieval/agent/agent_llm_callback.py b/api/core/rag/retrieval/agent/agent_llm_callback.py deleted file mode 100644 index 5ec549de8ee5b1..00000000000000 --- a/api/core/rag/retrieval/agent/agent_llm_callback.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -from typing import Optional - -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class AgentLLMCallback(Callback): - - def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None: - self.agent_callback = agent_callback - - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_before_invoke( - prompt_messages=prompt_messages - ) - - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - pass - - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_after_invoke( - result=result - ) - - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, 
- prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_error( - error=ex - ) diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py index 087b7bfa2c34e3..9b115bc696c009 100644 --- a/api/core/rag/retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -5,19 +5,17 @@ from langchain.schema import Generation, LLMResult from langchain.schema.language_model import BaseLanguageModel -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance -from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.fake_llm import FakeLLM class LLMChain(LCLLMChain): - model_config: ModelConfigEntity + model_config: EasyUIBasedModelConfigEntity """The language model instance to use.""" llm: BaseLanguageModel = FakeLLM(response="") parameters: dict[str, Any] = {} - agent_llm_callback: Optional[AgentLLMCallback] = None def generate( self, @@ -38,7 +36,6 @@ def generate( prompt_messages=prompt_messages, stream=False, stop=stop, - callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None, model_parameters=self.parameters ) diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 41a0c54041f8e9..84e2b0228fd7ff 100644 --- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -10,7 +10,7 @@ from langchain.tools import BaseTool from pydantic import root_validator -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool @@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ An Multi Dataset Retrieve Agent driven by Router. 
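     It exposes each dataset retriever as a callable tool, letting the LLM
     route a query to the most relevant dataset.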
""" - model_config: ModelConfigEntity + model_config: EasyUIBasedModelConfigEntity class Config: """Configuration for this pydantic object.""" @@ -156,7 +156,7 @@ async def aplan( @classmethod def from_llm_and_tools( cls, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index 4d7d33038bd8be..700bf0c293c8f1 100644 --- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py @@ -12,7 +12,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). @@ -206,7 +206,7 @@ def _construct_scratchpad( @classmethod def from_llm_and_tools( cls, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py index 7fabf71bedc3a1..749e603c5ce52e 100644 --- a/api/core/rag/retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -7,13 +7,12 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Extra +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.entities.agent_entities import PlanningStrategy -from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError -from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent @@ -23,15 +22,14 @@ class AgentConfiguration(BaseModel): strategy: PlanningStrategy - model_config: ModelConfigEntity + model_config: EasyUIBasedModelConfigEntity tools: list[BaseTool] - summary_model_config: Optional[ModelConfigEntity] = None + summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None memory: Optional[TokenBufferMemory] = None callbacks: Callbacks = None max_iterations: int = 6 max_execution_time: Optional[float] = None early_stopping_method: str = "generate" - agent_llm_callback: Optional[AgentLLMCallback] = None # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit class Config: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 
21e16c4162171d..8f1221adc79bd6 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -2,9 +2,10 @@ from langchain.tools import BaseTool +from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy -from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -17,7 +18,7 @@ class DatasetRetrieval: def retrieve(self, tenant_id: str, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, config: DatasetEntity, query: str, invoke_from: InvokeFrom, diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 629ed2361341b8..80062e606a8e38 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -2,8 +2,9 @@ from langchain.tools import BaseTool +from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8c335f201fb590..49eea603dc7758 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -1,4 +1,4 @@ -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db @@ -8,9 +8,9 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') - model_config = application_generate_entity.app_orchestration_config_entity.model_config + model_config = application_generate_entity.model_config provider_model_bundle = model_config.provider_model_bundle provider_configuration = provider_model_bundle.configuration @@ -43,7 +43,7 @@ def handle(sender, **kwargs): if used_quota is not None: db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, + Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == 
system_configuration.current_quota_type.value, diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 69b3a90e441659..d49e560a6705bd 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,10 +9,10 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, - Provider.provider_name == application_generate_entity.app_orchestration_config_entity.model_config.provider + Provider.tenant_id == application_generate_entity.app_config.tenant_id, + Provider.provider_name == application_generate_entity.model_config.provider ).update({'last_used': datetime.utcnow()}) db.session.commit() diff --git a/api/models/model.py b/api/models/model.py index e514ea729bd2aa..f8f9a0a3cddc6e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -105,6 +105,18 @@ def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant + @property + def is_agent(self) -> bool: + app_model_config = self.app_model_config + if not app_model_config: + return False + if not app_model_config.agent_mode: + return False + if self.app_model_config.agent_mode_dict.get('enabled', False) \ + and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + return True + return False + @property def deleted_tools(self) -> list: # get agent mode tools diff --git a/api/models/workflow.py b/api/models/workflow.py index ff4e944e29e2f1..f9c906b85c864c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -129,7 +129,7 @@ def graph_dict(self): def features_dict(self): return self.features if not self.features else json.loads(self.features) - def user_input_form(self): + def user_input_form(self) -> list: # get start node from graph if not self.graph: return [] diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index f2caeb14ff5217..c84f6fbf454daf 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,6 +1,6 @@ -from core.app.agent_chat.config_validator import AgentChatAppConfigValidator -from core.app.chat.config_validator import ChatAppConfigValidator -from core.app.completion.config_validator import CompletionAppConfigValidator +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from models.model import AppMode @@ -9,10 +9,10 @@ class AppModelConfigService: @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == 
AppMode.CHAT: - return ChatAppConfigValidator.config_validate(tenant_id, config) + return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: - return AgentChatAppConfigValidator.config_validate(tenant_id, config) + return AgentChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.COMPLETION: - return CompletionAppConfigValidator.config_validate(tenant_id, config) + return CompletionAppConfigManager.config_validate(tenant_id, config) else: raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 8a9639e521aded..453194feb1b7aa 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -4,9 +4,9 @@ from sqlalchemy import and_ -from core.app.app_manager import AppManager -from core.app.validators.model_validator import ModelValidator -from core.entities.application_entities import InvokeFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_manager import EasyUIBasedAppManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message @@ -30,7 +30,7 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, auto_generate_name = args['auto_generate_name'] \ if 'auto_generate_name' in args else True - if app_model.mode != 'completion': + if app_model.mode != AppMode.COMPLETION.value: if not query: raise ValueError('query is required') @@ -43,6 +43,7 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, conversation_id = args['conversation_id'] if 'conversation_id' in args else None conversation = None + app_model_config_dict = None if conversation_id: conversation_filter = [ Conversation.id == args['conversation_id'], @@ -63,42 +64,13 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, if conversation.status != 'normal': raise ConversationCompletedError() - if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() - if not app_model_config: - raise AppModelConfigBrokenError() - else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) - - app_model_config = AppModelConfig( - id=conversation.app_model_config_id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) - - if is_model_config_override: - # build new app model config - if 'model' not in args['model_config']: - raise ValueError('model_config.model is required') - - if 'completion_params' not in args['model_config']['model']: - raise ValueError('model_config.model.completion_params is required') - - completion_params = ModelValidator.validate_model_completion_params( - cp=args['model_config']['model']['completion_params'] - ) - - app_model_config_model = app_model_config.model_dict - app_model_config_model['completion_params'] = completion_params - app_model_config.retriever_resource = 
json.dumps({'enabled': True}) - - app_model_config = app_model_config.copy() - app_model_config.model = json.dumps(app_model_config_model) + if not app_model_config: + raise AppModelConfigBrokenError() else: if app_model.app_model_config_id is None: raise AppModelConfigBrokenError() @@ -113,37 +85,29 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, raise Exception("Only account can override model config") # validate config - model_config = AppModelConfigService.validate_configuration( + app_model_config_dict = AppModelConfigService.validate_configuration( tenant_id=app_model.tenant_id, config=args['model_config'], app_mode=AppMode.value_of(app_model.mode) ) - app_model_config = AppModelConfig( - id=app_model_config.id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(model_config) - - # clean input by app_model_config form rules - inputs = cls.get_cleaned_inputs(inputs, app_model_config) - # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - app_model_config, - user - ) + file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] - application_manager = AppManager() + application_manager = EasyUIBasedAppManager() return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=is_model_config_override, + app_model=app_model, + app_model_config=app_model_config, + app_model_config_dict=app_model_config_dict, user=user, invoke_from=invoke_from, inputs=inputs, @@ -189,17 +153,19 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) + file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.transform_message_files( + message.files, file_upload_entity + ) + else: + file_objs = [] - application_manager = AppManager() + application_manager = EasyUIBasedAppManager() return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, + app_model=app_model, + app_model_config=current_app_model_config, app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, user=user, invoke_from=invoke_from, inputs=message.inputs, @@ -212,46 +178,3 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], } ) - @classmethod - def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - # Filter input variables from form configuration, handle required fields, default values, and option values - input_form_config = app_model_config.user_input_form_list - for config in input_form_config: - input_config = list(config.values())[0] - variable = input_config["variable"] - - input_type = list(config.keys())[0] - - if variable not in 
user_inputs or not user_inputs[variable]: - if input_type == "external_data_tool": - continue - if "required" in input_config and input_config["required"]: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = input_config["default"] if "default" in input_config else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - - if input_type == "select": - options = input_config["options"] if "options" in input_config else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if 'max_length' in input_config: - max_length = input_config['max_length'] - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 6c0182dd9e5bfd..d62f1980143b04 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,16 +1,9 @@ import json from typing import Optional -from core.app.app_manager import AppManager -from core.entities.application_entities import ( - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - ModelConfigEntity, - PromptTemplateEntity, - VariableEntity, -) +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, FileUploadEntity +from core.app.app_manager import EasyUIBasedAppManager from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -36,7 +29,7 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: - basic mode of chatbot app - - advanced mode of assistant app + - expert mode of chatbot app - completion app @@ -86,14 +79,11 @@ def convert_app_model_config_to_workflow(self, app_model: App, # get new app mode new_app_mode = self._get_new_app_mode(app_model) - app_model_config_dict = app_model_config.to_dict() - # convert app model config - application_manager = AppManager() - app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( - tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config_dict, - skip_check=True + application_manager = EasyUIBasedAppManager() + app_config = application_manager.convert_to_app_config( + app_model=app_model, + app_model_config=app_model_config ) # init workflow graph @@ -113,27 +103,27 @@ def convert_app_model_config_to_workflow(self, app_model: App, # convert to start node start_node = self._convert_to_start_node( - variables=app_orchestration_config_entity.variables + variables=app_config.variables ) graph['nodes'].append(start_node) # convert to http request node - if app_orchestration_config_entity.external_data_variables: + if app_config.external_data_variables: http_request_nodes = self._convert_to_http_request_node( app_model=app_model, - variables=app_orchestration_config_entity.variables, - external_data_variables=app_orchestration_config_entity.external_data_variables + variables=app_config.variables, + external_data_variables=app_config.external_data_variables ) for 
http_request_node in http_request_nodes: graph = self._append_node(graph, http_request_node) # convert to knowledge retrieval node - if app_orchestration_config_entity.dataset: + if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=app_orchestration_config_entity.dataset + dataset_config=app_config.dataset ) if knowledge_retrieval_node: @@ -143,9 +133,9 @@ def convert_app_model_config_to_workflow(self, app_model: App, llm_node = self._convert_to_llm_node( new_app_mode=new_app_mode, graph=graph, - model_config=app_orchestration_config_entity.model_config, - prompt_template=app_orchestration_config_entity.prompt_template, - file_upload=app_orchestration_config_entity.file_upload + model_config=app_config.model, + prompt_template=app_config.prompt_template, + file_upload=app_config.additional_features.file_upload ) graph = self._append_node(graph, llm_node) @@ -155,6 +145,8 @@ def convert_app_model_config_to_workflow(self, app_model: App, graph = self._append_node(graph, end_node) + app_model_config_dict = app_config.app_model_config_dict + # features if new_app_mode == AppMode.ADVANCED_CHAT: features = { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 102c8617331b11..c9efd056ff2b98 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,8 +2,8 @@ from datetime import datetime from typing import Optional -from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator -from core.app.workflow.config_validator import WorkflowAppConfigValidator +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -162,13 +162,13 @@ def convert_to_workflow(self, app_model: App, account: Account) -> App: def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: - return AdvancedChatAppConfigValidator.config_validate( + return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: - return WorkflowAppConfigValidator.config_validate( + return WorkflowAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 69acb23681f17f..4357c6405c8a3d 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ import pytest -from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ - ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity +from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ + ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, 
AssistantPromptMessage, PromptMessageRole @@ -137,11 +137,11 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="https://example.com/image1.jpg", - file_config={ - "image": { + file_upload_entity=FileUploadEntity( + image_config={ "detail": "high", } - } + ) ) ] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 8a260b05072c1b..9796fc5558110f 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from core.entities.application_entities import ModelConfigEntity +from core.app.app_config.entities import ModelConfigEntity from core.entities.provider_configuration import ProviderModelBundle from core.model_runtime.entities.message_entities import UserPromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index a95a6dc52f524c..70f6070c6bbc3e 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform @@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p(): def test__get_chat_model_prompt_messages(): - model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' @@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): - model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-3.5-turbo-instruct' diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index d4edc73410eac3..0ca8ae135ce8dc 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,7 +4,7 @@ import pytest -from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity from core.helper import encrypter From 2eaae6742a9ad9e450a854a886cca544880f01b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:26 +0800 Subject: [PATCH 052/160] lint fix --- api/core/agent/base_agent_runner.py | 7 ++++--- api/core/agent/cot_agent_runner.py | 2 +- 
api/core/app/app_manager.py | 8 ++++---- api/core/memory/token_buffer_memory.py | 1 - api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/core/tools/tool/dataset_retriever_tool.py | 2 +- api/services/workflow/workflow_converter.py | 11 +++++++++-- 8 files changed, 21 insertions(+), 14 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 529240aecb7b30..f22ca7653f2ccf 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -9,12 +9,13 @@ from core.app.app_queue_manager import AppQueueManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, - InvokeFrom, EasyUIBasedModelConfigEntity, + EasyUIBasedModelConfigEntity, + InvokeFrom, ) +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 5b345f4da0c6a6..8b444ef3be3f9b 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,8 +4,8 @@ from typing import Literal, Union from core.agent.base_agent_runner import BaseAgentRunner -from core.app.app_queue_manager import PublishFrom from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit +from core.app.app_queue_manager import PublishFrom from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py index 98ebe2c87d8ab8..ea8a97f8780415 100644 --- a/api/core/app/app_manager.py +++ b/api/core/app/app_manager.py @@ -9,26 +9,26 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, EasyUIBasedAppConfig, VariableEntity +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, VariableEntity +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner -from core.app.generate_task_pipeline import GenerateTaskPipeline from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, InvokeFrom, ) +from 
core.app.generate_task_pipeline import GenerateTaskPipeline from core.file.file_obj import FileObj from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile logger = logging.getLogger(__name__) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 4fe150e9833df9..471400f09baffc 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,4 +1,3 @@ -from core.app.app_config.entities import FileUploadEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 129c2a4cd273ea..cdd03b85f1ff3c 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8f1221adc79bd6..37581f1e92fee8 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -3,9 +3,9 @@ from langchain.tools import BaseTool from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy -from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 80062e606a8e38..1522d3af092cb6 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -3,8 +3,8 @@ from langchain.tools import BaseTool from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from 
core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index d62f1980143b04..b3061cc2552d67 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,8 +1,15 @@ import json from typing import Optional -from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ - DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, FileUploadEntity +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, +) from core.app.app_manager import EasyUIBasedAppManager from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode From b80092ea1243fc039a45d85b04f1744f9c20db49 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:31 +0800 Subject: [PATCH 053/160] lint fix --- api/core/agent/entities.py | 2 +- api/core/app/app_config/base_app_config_manager.py | 7 ++++--- .../easy_ui_based_app/model_config/converter.py | 1 - .../easy_ui_based_app/model_config/manager.py | 2 +- .../easy_ui_based_app/prompt_template/manager.py | 7 +++++-- .../app_config/easy_ui_based_app/variables/manager.py | 5 ++--- .../app_config/features/opening_statement/manager.py | 3 +-- api/core/app/apps/advanced_chat/app_config_manager.py | 7 ++++--- api/core/app/apps/agent_chat/app_config_manager.py | 11 ++++++----- api/core/app/apps/base_app_runner.py | 9 +++++---- api/core/app/apps/chat/app_config_manager.py | 9 +++++---- api/core/app/apps/chat/app_runner.py | 4 ++-- api/core/app/apps/completion/app_config_manager.py | 4 ++-- api/core/app/apps/completion/app_runner.py | 4 ++-- api/core/app/apps/workflow/app_config_manager.py | 2 +- 15 files changed, 41 insertions(+), 36 deletions(-) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 0fbfdc26367237..e7016d6030cc20 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Any, Union, Optional +from typing import Any, Literal, Optional, Union from pydantic import BaseModel diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index b3c773203d7fbf..e09aa0376685a1 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Optional, Union from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -6,8 +6,9 @@ from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import 
TextToSpeechConfigManager from models.model import AppModelConfig diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 05fcb107919673..610e9bce3215e0 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -2,7 +2,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity - from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.model_entities import ModelType diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5cca2bc1a74be5..730a9527cf7315 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,5 +1,5 @@ from core.app.app_config.entities import ModelConfigEntity -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.provider_manager import ProviderManager diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 5629d0d09e4537..1f410758aa41da 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,8 @@ -from core.app.app_config.entities import PromptTemplateEntity, \ - AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.app.app_config.entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) from core.model_runtime.entities.message_entities import PromptMessageRole from core.prompt.simple_prompt_transform import ModelMode from models.model import AppMode diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index ff962a5439512a..1237da502b4258 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,13 +1,12 @@ import re -from typing import Tuple -from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity from core.external_data_tool.factory import ExternalDataToolFactory class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> Tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 6183c6e7493311..0d8a71bfcf4d8b 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,9 +1,8 @@ -from 
typing import Tuple class OpeningStatementConfigManager: @classmethod - def convert(cls, config: dict) -> Tuple[str, list]: + def convert(cls, config: dict) -> tuple[str, list]: """ Convert model config to model config diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index ab7857c4adad7d..d0909ead70d585 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -5,11 +5,12 @@ from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import AppMode, App +from models.model import App, AppMode from models.workflow import Workflow diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 96dac4bd010f8e..55a04832aa598c 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -3,22 +3,23 @@ from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager -from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, DatasetEntity +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, 
AppModelConfig OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 93f819af0883ed..64c1a464918142 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -2,14 +2,15 @@ from collections.abc import Generator from typing import Optional, Union, cast -from core.app.app_config.entities import PromptTemplateEntity, ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature -from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, - InvokeFrom, EasyUIBasedModelConfigEntity, + EasyUIBasedModelConfigEntity, + InvokeFrom, ) +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 62b2aaae5a5b2b..ff0195563efb3a 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,20 +1,21 @@ from typing import Optional from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, AppModelConfig class ChatAppConfig(EasyUIBasedAppConfig): diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 403a2d4476e304..1b256f11c4a7f5 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -2,12 +2,12 
@@ from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, ) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index b920f369b5e93f..6bdb7cc4b3875a 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,16 +1,16 @@ from typing import Optional from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, AppModelConfig class CompletionAppConfig(EasyUIBasedAppConfig): diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 8f0f191d454f84..d60e14aaeb8597 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -2,12 +2,12 @@ from typing import cast from core.app.app_queue_manager import AppQueueManager -from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, ) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 35da72b63e0043..194339a23b54cf 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -4,7 +4,7 @@ from core.app.app_config.features.file_upload.manager import 
FileUploadConfigManager
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
 from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
-from models.model import AppMode, App
+from models.model import App, AppMode
 from models.workflow import Workflow

From 06b05163f673f533e48810e72eef9a1e60cf89f4 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sat, 2 Mar 2024 15:53:40 +0800
Subject: [PATCH 054/160] update app import response

---
 api/controllers/console/app/app.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 98636fa95f3e03..db23a028cd63a3 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -76,7 +76,7 @@ class AppImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(app_detail_fields)
+    @marshal_with(app_detail_fields_with_site)
     @cloud_edition_billing_resource_check('apps')
     def post(self):
         """Import app"""

From 09dfe80718d85fef7e31a2547d174b2af6355fd1 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sat, 2 Mar 2024 15:57:34 +0800
Subject: [PATCH 055/160] add app copy api

---
 api/controllers/console/app/app.py | 29 ++++++++++++++++++++++++++++-
 api/services/app_service.py        |  5 +++--
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index db23a028cd63a3..7b2411b96fd4a9 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -93,7 +93,7 @@ def post(self):
         args = parser.parse_args()
 
         app_service = AppService()
-        app = app_service.import_app(current_user.current_tenant_id, args, current_user)
+        app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
 
         return app, 201
 
@@ -180,6 +180,32 @@ def delete(self, app_model):
         return {'result': 'success'}, 204
 
 
+class AppCopyApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model
+    @marshal_with(app_detail_fields_with_site)
+    def post(self, app_model):
+        """Copy app"""
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, location='json')
+        parser.add_argument('description', type=str, location='json')
+        parser.add_argument('icon', type=str, location='json')
+        parser.add_argument('icon_background', type=str, location='json')
+        args = parser.parse_args()
+
+        app_service = AppService()
+        data = app_service.export_app(app_model)
+        app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
+
+        return app, 201
+
+
 class AppExportApi(Resource):
     @setup_required
     @login_required
@@ -266,6 +292,7 @@ def post(self, app_model):
 api.add_resource(AppListApi, '/apps')
 api.add_resource(AppImportApi, '/apps/import')
 api.add_resource(AppApi, '/apps/<uuid:app_id>')
+api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
 api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
 api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
 api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
diff --git a/api/services/app_service.py b/api/services/app_service.py
index e0a7835cb72a5d..f1d0e3df19393b 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -124,15 +124,16 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
 
         return app
 
-    def import_app(self,
tenant_id: str, args: dict, account: Account) -> App: + def import_app(self, tenant_id: str, data: str, args: dict, account: Account) -> App: """ Import app :param tenant_id: tenant id + :param data: import data :param args: request args :param account: Account instance """ try: - import_data = yaml.safe_load(args['data']) + import_data = yaml.safe_load(data) except yaml.YAMLError as e: raise ValueError("Invalid YAML format in data argument.") From e498efce2d79587628bcb8c904af2843971e8549 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 3 Mar 2024 04:18:38 +0800 Subject: [PATCH 056/160] refactor app generate --- api/controllers/console/app/completion.py | 6 +- api/core/agent/base_agent_runner.py | 13 +- .../model_config/converter.py | 8 +- api/core/app/app_manager.py | 468 ------------------ .../apps/advanced_chat/app_config_manager.py | 8 +- .../app/apps/agent_chat/app_config_manager.py | 25 +- api/core/app/apps/agent_chat/app_generator.py | 194 ++++++++ api/core/app/apps/agent_chat/app_runner.py | 7 +- api/core/app/apps/base_app_generator.py | 42 ++ api/core/app/apps/base_app_runner.py | 13 +- api/core/app/apps/chat/app_config_manager.py | 25 +- api/core/app/apps/chat/app_generator.py | 194 ++++++++ api/core/app/apps/chat/app_runner.py | 4 +- .../app/apps/completion/app_config_manager.py | 21 +- api/core/app/apps/completion/app_generator.py | 292 +++++++++++ api/core/app/apps/completion/app_runner.py | 4 +- .../app/apps/message_based_app_generator.py | 251 ++++++++++ .../app/apps/workflow/app_config_manager.py | 2 +- api/core/app/entities/app_invoke_entities.py | 72 ++- .../hosting_moderation/hosting_moderation.py | 2 +- api/core/app/generate_task_pipeline.py | 18 +- api/core/helper/moderation.py | 4 +- api/core/prompt/advanced_prompt_transform.py | 10 +- api/core/prompt/prompt_transform.py | 6 +- api/core/prompt/simple_prompt_transform.py | 10 +- api/core/rag/retrieval/agent/llm_chain.py | 4 +- .../agent/multi_dataset_router_agent.py | 6 +- .../structed_multi_dataset_router_agent.py | 4 +- .../retrieval/agent_based_dataset_executor.py | 6 +- api/core/rag/retrieval/dataset_retrieval.py | 4 +- .../deduct_quota_when_messaeg_created.py | 4 +- ...vider_last_used_at_when_messaeg_created.py | 4 +- api/services/completion_service.py | 209 ++------ api/services/workflow/workflow_converter.py | 39 +- .../prompt/test_simple_prompt_transform.py | 6 +- 35 files changed, 1235 insertions(+), 750 deletions(-) delete mode 100644 api/core/app/app_manager.py create mode 100644 api/core/app/apps/agent_chat/app_generator.py create mode 100644 api/core/app/apps/base_app_generator.py create mode 100644 api/core/app/apps/chat/app_generator.py create mode 100644 api/core/app/apps/completion/app_generator.py create mode 100644 api/core/app/apps/message_based_app_generator.py diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index ed1522c0cdf891..fd6cfadfef8575 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -59,8 +59,7 @@ def post(self, app_model): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) @@ -126,8 +125,7 @@ def post(self, app_model): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) diff --git a/api/core/agent/base_agent_runner.py 
b/api/core/agent/base_agent_runner.py index f22ca7653f2ccf..ef530b9122237b 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -10,9 +10,8 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, - EasyUIBasedModelConfigEntity, - InvokeFrom, + ModelConfigWithCredentialsEntity, + InvokeFrom, AgentChatAppGenerateEntity, ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -49,9 +48,9 @@ class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, - application_generate_entity: EasyUIBasedAppGenerateEntity, + application_generate_entity: AgentChatAppGenerateEntity, app_config: AgentChatAppConfig, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, config: AgentEntity, queue_manager: AppQueueManager, message: Message, @@ -123,8 +122,8 @@ def __init__(self, tenant_id: str, else: self.stream_tool_call = False - def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \ - -> EasyUIBasedAppGenerateEntity: + def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ + -> AgentChatAppGenerateEntity: """ Repack app generate entity """ diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 610e9bce3215e0..5c9b2cfec7babf 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,7 +1,7 @@ from typing import cast from core.app.app_config.entities import EasyUIBasedAppConfig -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.model_entities import ModelType @@ -9,11 +9,11 @@ from core.provider_manager import ProviderManager -class EasyUIBasedModelConfigEntityConverter: +class ModelConfigConverter: @classmethod def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) \ - -> EasyUIBasedModelConfigEntity: + -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. 
:param app_config: app config @@ -91,7 +91,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") - return EasyUIBasedModelConfigEntity( + return ModelConfigWithCredentialsEntity( provider=model_config.provider, model=model_config.model, model_schema=model_schema, diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py deleted file mode 100644 index ea8a97f8780415..00000000000000 --- a/api/core/app/app_manager.py +++ /dev/null @@ -1,468 +0,0 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast - -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter -from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, VariableEntity -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom -from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager -from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.apps.chat.app_config_manager import ChatAppConfigManager -from core.app.apps.chat.app_runner import ChatAppRunner -from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.app.apps.completion.app_runner import CompletionAppRunner -from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, - InvokeFrom, -) -from core.app.generate_task_pipeline import GenerateTaskPipeline -from core.file.file_obj import FileObj -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from extensions.ext_database import db -from models.account import Account -from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) - - -class EasyUIBasedAppManager: - - def generate(self, app_model: App, - app_model_config: AppModelConfig, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - app_model_config_dict: Optional[dict] = None, - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. 
- - :param app_model: App - :param app_model_config: app model config - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param app_model_config_dict: app model config dict - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # convert to app config - app_config = self.convert_to_app_config( - app_model=app_model, - app_model_config=app_model_config, - app_model_config_dict=app_model_config_dict, - conversation=conversation - ) - - # init application generate entity - application_generate_entity = EasyUIBasedAppGenerateEntity( - task_id=task_id, - app_config=app_config, - model_config=EasyUIBasedModelConfigEntityConverter.convert(app_config), - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = AppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig, - app_model_config_dict: Optional[dict] = None, - conversation: Optional[Conversation] = None) -> EasyUIBasedAppConfig: - if app_model_config_dict: - config_from = EasyUIBasedAppModelConfigFrom.ARGS - elif conversation: - config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG - else: - config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG - - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT.value - app_config = AgentChatAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - elif app_mode == AppMode.COMPLETION: - app_config = CompletionAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - else: 
- raise ValueError("Invalid app mode") - - return app_config - - def _get_cleaned_inputs(self, user_inputs: dict, app_config: EasyUIBasedAppConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - # Filter input variables from form configuration, handle required fields, default values, and option values - variables = app_config.variables - for variable_config in variables: - variable = variable_config.variable - - if variable not in user_inputs or not user_inputs[variable]: - if variable_config.required: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - - if variable_config.type == VariableEntity.Type.SELECT: - options = variable_config.options if variable_config.options is not None else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if variable_config.max_length is not None: - max_length = variable_config.max_length - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. - :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: - # agent app - runner = AgentChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - elif application_generate_entity.app_config.app_mode == AppMode.CHAT: - # chatbot app - runner = ChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - elif application_generate_entity.app_config.app_mode == AppMode.COMPLETION: - # completion app - runner = CompletionAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) - else: - raise ValueError("Invalid app mode") - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - 
db.session.remove() - - def _handle_response(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - finally: - db.session.remove() - - def _init_generate_records(self, application_generate_entity: EasyUIBasedAppGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - model_type_instance = application_generate_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=application_generate_entity.model_config.model, - credentials=application_generate_entity.model_config.credentials - ) - - app_config = application_generate_entity.app_config - - app_record = (db.session.query(App) - .filter(App.id == app_config.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS: - override_model_configs = app_config.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - app_model_config_id=app_config.app_model_config_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - model_provider=application_generate_entity.model_config.provider, - 
model_id=application_generate_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_config.app_mode == AppMode.AGENT_CHAT, - ) - - db.session.add(message) - db.session.commit() - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: EasyUIBasedAppGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_config = application_generate_entity.app_config - introduction = app_config.additional_features.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) - - return message diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index d0909ead70d585..72ba4c33d4e1a4 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -10,7 +12,7 @@ ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import App, AppMode +from models.model import App, AppMode, Conversation from models.workflow import Workflow @@ -23,7 +25,9 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def config_convert(cls, app_model: 
App, workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       workflow: Workflow,
+                       conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
 
         app_config = AdvancedChatAppConfig(
diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py
index 55a04832aa598c..57214f924a6023 100644
--- a/api/core/app/apps/agent_chat/app_config_manager.py
+++ b/api/core/app/apps/agent_chat/app_config_manager.py
@@ -19,7 +19,7 @@
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
 from core.entities.agent_entities import PlanningStrategy
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation
 
 OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
 
@@ -33,19 +33,30 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
 
 class AgentChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
+    def get_app_config(cls, app_model: App,
                        app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> AgentChatAppConfig:
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
         """
         Convert app model config to agent chat app config
         :param app_model: app model
-        :param config_from: app model config from
         :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
         :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = AgentChatAppConfig(
             tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
new file mode 100644
index 00000000000000..1ab456d8223532
--- /dev/null
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -0,0 +1,194 @@
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
+from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity
+from core.file.message_file_parser import MessageFileParser
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from extensions.ext_database import db
+from models.account import
Account +from models.model import App, EndUser + +logger = logging.getLogger(__name__) + + +class AgentChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = AgentChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return 
self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AgentChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2f1de8f108afba..6bae5e1648188e 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -7,7 +7,8 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \ + AgentChatAppGenerateEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -26,7 +27,7 @@ class AgentChatAppRunner(AppRunner): """ Agent Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -292,7 +293,7 @@ def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversation 'pool': db_variables.variables }) - def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity, + def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, message: Message) -> LLMUsage: """ Get usage of all agent thoughts diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py new file mode 100644 index 00000000000000..65764021aaf3ac --- /dev/null +++ b/api/core/app/apps/base_app_generator.py @@ -0,0 +1,42 @@ +from core.app.app_config.entities import VariableEntity, AppConfig + + +class BaseAppGenerator: + def 
_get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
+        if user_inputs is None:
+            user_inputs = {}
+
+        filtered_inputs = {}
+
+        # Filter input variables from form configuration, handle required fields, default values, and option values
+        variables = app_config.variables
+        for variable_config in variables:
+            variable = variable_config.variable
+
+            if variable not in user_inputs or not user_inputs[variable]:
+                if variable_config.required:
+                    raise ValueError(f"{variable} is required in input form")
+                else:
+                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
+                continue
+
+            value = user_inputs[variable]
+
+            if value:
+                if not isinstance(value, str):
+                    raise ValueError(f"{variable} in input form must be a string")
+
+                if variable_config.type == VariableEntity.Type.SELECT:
+                    options = variable_config.options if variable_config.options is not None else []
+                    if value not in options:
+                        raise ValueError(f"{variable} in input form must be one of the following: {options}")
+                else:
+                    if variable_config.max_length is not None:
+                        max_length = variable_config.max_length
+                        if len(value) > max_length:
+                            raise ValueError(f'{variable} in input form must not exceed {max_length} characters')
+
+            filtered_inputs[variable] = value.replace('\x00', '') if value else None
+
+        return filtered_inputs
+
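The rules above are easiest to see against a concrete form. The variables, limits, and options below are invented for illustration only:

    # Hypothetical app form:
    #   name  - text input, required, max_length=8
    #   topic - select, options=["news", "sports"]
    #
    # gen._get_cleaned_inputs({'name': 'alice', 'topic': 'news'}, app_config)
    #   -> {'name': 'alice', 'topic': 'news'}
    # gen._get_cleaned_inputs({'topic': 'news'}, app_config)
    #   -> ValueError: name is required in input form
    # gen._get_cleaned_inputs({'name': 'alice', 'topic': 'x'}, app_config)
    #   -> ValueError: topic in input form must be one of the following: ['news', 'sports']
    # gen._get_cleaned_inputs({'name': 'a' * 9, 'topic': 'news'}, app_config)
    #   -> ValueError: name in input form must not exceed 8 characters

diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 64c1a464918142..ee70f161a27181 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -5,9 +5,8 @@
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
-    EasyUIBasedModelConfigEntity,
-    InvokeFrom,
+    ModelConfigWithCredentialsEntity,
+    InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity,
 )
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
@@ -27,7 +26,7 @@ class AppRunner:
     def get_pre_calculate_rest_tokens(self, app_record: App,
-                                      model_config: EasyUIBasedModelConfigEntity,
+                                      model_config: ModelConfigWithCredentialsEntity,
                                       prompt_template_entity: PromptTemplateEntity,
                                       inputs: dict[str, str],
                                       files: list[FileObj],
@@ -83,7 +82,7 @@ def get_pre_calculate_rest_tokens(self, app_record: App,
 
         return rest_tokens
 
-    def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity,
+    def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                               prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
         model_type_instance = model_config.provider_model_bundle.model_type_instance
@@ -119,7 +118,7 @@ def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity,
             model_config.parameters[parameter_rule.name] = max_tokens
 
     def organize_prompt_messages(self, app_record: App,
-                                 model_config: EasyUIBasedModelConfigEntity,
+                                 model_config: ModelConfigWithCredentialsEntity,
                                  prompt_template_entity: PromptTemplateEntity,
                                  inputs: dict[str, str],
                                  files: list[FileObj],
@@ -292,7 +291,7 @@ def _handle_invoke_result_stream(self, invoke_result: Generator,
     def moderation_for_inputs(self, app_id: str,
                               tenant_id: str,
-                              app_generate_entity: EasyUIBasedAppGenerateEntity,
+                              app_generate_entity: AppGenerateEntity,
                               inputs: dict,
                               query: str) -> tuple[bool, dict, str]: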
         """
diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py
index ff0195563efb3a..ac69a928231420 100644
--- a/api/core/app/apps/chat/app_config_manager.py
+++ b/api/core/app/apps/chat/app_config_manager.py
@@ -15,7 +15,7 @@
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation
 
 
 class ChatAppConfig(EasyUIBasedAppConfig):
@@ -27,19 +27,30 @@ class ChatAppConfig(EasyUIBasedAppConfig):
 
 class ChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
+    def get_app_config(cls, app_model: App,
                        app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> ChatAppConfig:
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> ChatAppConfig:
         """
         Convert app model config to chat app config
         :param app_model: app model
-        :param config_from: app model config from
         :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
         :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = ChatAppConfig(
             tenant_id=app_model.tenant_id,
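The generator introduced below becomes the single entry point for chat apps. A usage sketch under stated assumptions: an `App` row and an `Account`/`EndUser` are already loaded by the caller, and the arg values are invented for illustration:

    # Sketch only: drive ChatAppGenerator from a service layer.
    response = ChatAppGenerator().generate(
        app_model=app_model,
        user=user,
        args={
            'query': 'Hello',            # required, must be a str
            'inputs': {},                # form inputs, cleaned via _get_cleaned_inputs
            'conversation_id': None,     # omit/None starts a new conversation
            'auto_generate_name': True,  # optional, defaults to True
        },
        invoke_from=InvokeFrom.WEB_APP,
        stream=True,                     # Generator when True, dict otherwise
    )

diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
new file mode 100644
index 00000000000000..712822f3a5664f
--- /dev/null
+++ b/api/core/app/apps/chat/app_generator.py
@@ -0,0 +1,194 @@
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.chat.app_config_manager import ChatAppConfigManager
+from core.app.apps.chat.app_runner import ChatAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity
+from core.file.message_file_parser import MessageFileParser
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, EndUser
+
+logger = logging.getLogger(__name__)
+
+
+class ChatAppGenerator(MessageBasedAppGenerator):
+    def generate(self, app_model: App,
+                 user: Union[Account, EndUser],
+                 args: Any,
+                 invoke_from: InvokeFrom,
+                 stream: bool = True) \
+            -> Union[dict, Generator]:
+        """
+        Generate App response.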
+ + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = ChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: 
+ """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 1b256f11c4a7f5..57aca9d3e613d5 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -5,7 +5,7 @@ from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, + ChatAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -23,7 +23,7 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: ChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 6bdb7cc4b3875a..77a14430373bff 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -10,7 +10,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Conversation class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,19 +22,26 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def config_convert(cls, app_model: App, - config_from: EasyUIBasedAppModelConfigFrom, + def get_app_config(cls, app_model: App, app_model_config: AppModelConfig, - config_dict: Optional[dict] = None) -> CompletionAppConfig: + override_config_dict: Optional[dict] = None) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model - :param config_from: app model config from :param app_model_config: app model config - :param 
config_dict: app model config dict
+        :param override_config_dict: app model config dict
         :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = CompletionAppConfig(
             tenant_id=app_model.tenant_id,
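Completion apps carry no conversation to pin a config to, so the resolution above collapses to two cases; a sketch (caller names illustrative):

    # Override supplied (debug mode): config comes from the request args.
    app_config = CompletionAppConfigManager.get_app_config(
        app_model=app_model,
        app_model_config=app_model_config,
        override_config_dict=override_model_config_dict,
    )

    # No override: the app's latest stored config is used.
    app_config = CompletionAppConfigManager.get_app_config(
        app_model=app_model,
        app_model_config=app_model_config,
    )

diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
new file mode 100644
index 00000000000000..d258a3bd9da7e8
--- /dev/null
+++ b/api/core/app/apps/completion/app_generator.py
@@ -0,0 +1,292 @@
+import json
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
+from core.app.apps.completion.app_runner import CompletionAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, CompletionAppGenerateEntity
+from core.file.message_file_parser import MessageFileParser
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, EndUser, Message
+from services.errors.app import MoreLikeThisDisabledError
+from services.errors.message import MessageNotExistsError
+
+logger = logging.getLogger(__name__)
+
+
+class CompletionAppGenerator(MessageBasedAppGenerator):
+    def generate(self, app_model: App,
+                 user: Union[Account, EndUser],
+                 args: Any,
+                 invoke_from: InvokeFrom,
+                 stream: bool = True) \
+            -> Union[dict, Generator]:
+        """
+        Generate App response.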
+ + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = {} + + # get conversation + conversation = None + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = CompletionAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + inputs=self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str) -> None: + """ + Generate worker in a new thread. 
+        :param flask_app: Flask app
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param message_id: message ID
+        :return:
+        """
+        with flask_app.app_context():
+            try:
+                # get message
+                message = self._get_message(message_id)
+
+                # completion app
+                runner = CompletionAppRunner()
+                runner.run(
+                    application_generate_entity=application_generate_entity,
+                    queue_manager=queue_manager,
+                    message=message
+                )
+            except ConversationTaskStoppedException:
+                pass
+            except InvokeAuthorizationError:
+                queue_manager.publish_error(
+                    InvokeAuthorizationError('Incorrect API key provided'),
+                    PublishFrom.APPLICATION_MANAGER
+                )
+            except ValidationError as e:
+                logger.exception("Validation Error when generating")
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
+            except (ValueError, InvokeError) as e:
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
+            except Exception as e:
+                logger.exception("Unknown Error when generating")
+                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
+            finally:
+                db.session.remove()
+
+    def generate_more_like_this(self, app_model: App,
+                                message_id: str,
+                                user: Union[Account, EndUser],
+                                invoke_from: InvokeFrom,
+                                stream: bool = True) \
+            -> Union[dict, Generator]:
+        """
+        Generate a more-like-this response for an existing message.
+
+        :param app_model: App
+        :param message_id: message ID
+        :param user: account or end user
+        :param invoke_from: invoke from source
+        :param stream: is stream
+        """
+        message = db.session.query(Message).filter(
+            Message.id == message_id,
+            Message.app_id == app_model.id,
+            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
+            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+            Message.from_account_id == (user.id if isinstance(user, Account) else None),
+        ).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        current_app_model_config = app_model.app_model_config
+        more_like_this = current_app_model_config.more_like_this_dict
+
+        if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
+            raise MoreLikeThisDisabledError()
+
+        app_model_config = message.app_model_config
+        override_model_config_dict = app_model_config.to_dict()
+        model_dict = override_model_config_dict['model']
+        completion_params = model_dict.get('completion_params')
+        completion_params['temperature'] = 0.9
+        model_dict['completion_params'] = completion_params
+        override_model_config_dict['model'] = model_dict
+
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
+        if file_upload_entity:
+            file_objs = message_file_parser.validate_and_transform_files_arg(
+                message.files,
+                file_upload_entity,
+                user
+            )
+        else:
+            file_objs = []
+
+        # convert to app config
+        app_config = CompletionAppConfigManager.get_app_config(
+            app_model=app_model,
+            app_model_config=app_model_config,
+            override_config_dict=override_model_config_dict
+        )
+
+        # init application generate entity
+        application_generate_entity = CompletionAppGenerateEntity(
+            task_id=str(uuid.uuid4()),
+            app_config=app_config,
+            model_config=ModelConfigConverter.convert(app_config),
+            inputs=message.inputs,
+            query=message.query,
+            files=file_objs,
+            user_id=user.id,
+            stream=stream,
+            invoke_from=invoke_from,
+            extras={}
+        )
+
+        # init generate records
+        (
+
conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index d60e14aaeb8597..c5b8ca6c9a1e2e 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -5,7 +5,7 @@ from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, + CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance @@ -22,7 +22,7 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message) -> None: """ diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py new file mode 100644 index 00000000000000..783c6c6ee52063 --- /dev/null +++ b/api/core/app/apps/message_based_app_generator.py @@ -0,0 +1,251 @@ +import json +import logging +from typing import Union, Generator, Optional + +from sqlalchemy import and_ + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_queue_manager import ConversationTaskStoppedException, AppQueueManager +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity, AppGenerateEntity, \ + CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity +from core.app.generate_task_pipeline import GenerateTaskPipeline +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import Conversation, Message, AppMode, MessageFile, App, EndUser, AppModelConfig +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError + +logger = logging.getLogger(__name__) + + +class MessageBasedAppGenerator(BaseAppGenerator): + + def _handle_response(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
+ :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = GenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() + + def _get_conversation_by_user(self, app_model: App, conversation_id: str, + user: Union[Account, EndUser]) -> Conversation: + conversation_filter = [ + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.status == 'normal' + ] + + if isinstance(user, Account): + conversation_filter.append(Conversation.from_account_id == user.id) + else: + conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) + + conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() + + if not conversation: + raise ConversationNotExistsError() + + if conversation.status != 'normal': + raise ConversationCompletedError() + + return conversation + + def _get_app_model_config(self, app_model: App, + conversation: Optional[Conversation] = None) \ + -> AppModelConfig: + if conversation: + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() + + if not app_model_config: + raise AppModelConfigBrokenError() + else: + if app_model.app_model_config_id is None: + raise AppModelConfigBrokenError() + + app_model_config = app_model.app_model_config + + if not app_model_config: + raise AppModelConfigBrokenError() + + return app_model_config + + def _init_generate_records(self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], + conversation: Optional[Conversation] = None) \ + -> tuple[Conversation, Message]: + """ + Initialize generate records + :param application_generate_entity: application generate entity + :return: + """ + app_config = application_generate_entity.app_config + + # get from source + end_user_id = None + account_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + from_source = 'api' + end_user_id = application_generate_entity.user_id + else: + from_source = 'console' + account_id = application_generate_entity.user_id + + override_model_configs = None + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ + and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + override_model_configs = app_config.app_model_config_dict + + # get conversation introduction + introduction = self._get_conversation_introduction(application_generate_entity) + + if not conversation: + conversation = Conversation( + app_id=app_config.app_id, + app_model_config_id=app_config.app_model_config_id, + model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_config.app_mode.value, + 
name='New conversation', + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status='normal', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.commit() + + message = Message( + app_id=app_config.app_id, + model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=conversation.id, + inputs=application_generate_entity.inputs, + query=application_generate_entity.query or "", + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency='USD', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id + ) + + db.session.add(message) + db.session.commit() + + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + belongs_to='user', + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if account_id else 'end_user'), + created_by=account_id or end_user_id, + ) + db.session.add(message_file) + db.session.commit() + + return conversation, message + + def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: + """ + Get conversation introduction + :param application_generate_entity: application generate entity + :return: conversation introduction + """ + app_config = application_generate_entity.app_config + introduction = app_config.additional_features.opening_statement + + if introduction: + try: + inputs = application_generate_entity.inputs + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + introduction = prompt_template.format(prompt_inputs) + except KeyError: + pass + + return introduction + + def _get_conversation(self, conversation_id: str) -> Conversation: + """ + Get conversation by conversation id + :param conversation_id: conversation id + :return: conversation + """ + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + return conversation + + def _get_message(self, message_id: str) -> Message: + """ + Get message by message id + :param message_id: message id + :return: message + """ + message = ( + db.session.query(Message) + .filter(Message.id == message_id) + .first() + ) + + return message diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 194339a23b54cf..91bab1b21896c9 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): class WorkflowAppConfigManager(BaseAppConfigManager): @classmethod - def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: features_dict = workflow.features_dict app_config = WorkflowAppConfig( diff --git a/api/core/app/entities/app_invoke_entities.py 
b/api/core/app/entities/app_invoke_entities.py
index fae9044fc34f1e..9097345674f3bf 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -3,7 +3,7 @@
 
 from pydantic import BaseModel
 
-from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
+from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file.file_obj import FileObj
 from core.model_runtime.entities.model_entities import AIModelEntity
@@ -49,9 +49,9 @@ def to_source(self) -> str:
         return 'dev'
 
 
-class EasyUIBasedModelConfigEntity(BaseModel):
+class ModelConfigWithCredentialsEntity(BaseModel):
     """
-    Model Config Entity.
+    Model Config With Credentials Entity.
     """
     provider: str
     model: str
@@ -63,21 +63,19 @@ class EasyUIBasedModelConfigEntity(BaseModel):
     stop: list[str] = []
 
 
-class EasyUIBasedAppGenerateEntity(BaseModel):
+class AppGenerateEntity(BaseModel):
     """
-    EasyUI Based Application Generate Entity.
+    App Generate Entity.
     """
     task_id: str
 
     # app config
-    app_config: EasyUIBasedAppConfig
-    model_config: EasyUIBasedModelConfigEntity
+    app_config: AppConfig
 
-    conversation_id: Optional[str] = None
     inputs: dict[str, str]
-    query: Optional[str] = None
     files: list[FileObj] = []
     user_id: str
+
     # extras
     stream: bool
     invoke_from: InvokeFrom
@@ -86,26 +84,52 @@ class EasyUIBasedAppGenerateEntity(BaseModel):
     extras: dict[str, Any] = {}
 
 
-class WorkflowUIBasedAppGenerateEntity(BaseModel):
+class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """
-    Workflow UI Based Application Generate Entity.
+    EasyUI Based Application Generate Entity.
     """
-    task_id: str
-    # app config
-    app_config: WorkflowUIBasedAppConfig
+    # app config
+    app_config: EasyUIBasedAppConfig
+    model_config: ModelConfigWithCredentialsEntity
 
-    inputs: dict[str, str]
-    files: list[FileObj] = []
-    user_id: str
-    # extras
-    stream: bool
-    invoke_from: InvokeFrom
+    query: Optional[str] = None
 
-    # extra parameters
-    extras: dict[str, Any] = {}
 
+class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Chat Application Generate Entity.
+    """
+    conversation_id: Optional[str] = None
 
-class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity):
+
+class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Completion Application Generate Entity.
+    """
+    pass
+
+
+class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Agent Chat Application Generate Entity.
+    """
     conversation_id: Optional[str] = None
-    query: str
+
+
+class AdvancedChatAppGenerateEntity(AppGenerateEntity):
+    """
+    Advanced Chat Application Generate Entity.
+    """
+    # app config
+    app_config: WorkflowUIBasedAppConfig
+
+    conversation_id: Optional[str] = None
+    query: Optional[str] = None
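For orientation, the generate-entity hierarchy this hunk establishes, including the workflow entity that follows (a reader's summary, not code from the patch):

    # AppGenerateEntity
    # ├── EasyUIBasedAppGenerateEntity      (+ app_config, model_config, query)
    # │   ├── ChatAppGenerateEntity         (+ conversation_id)
    # │   ├── CompletionAppGenerateEntity
    # │   └── AgentChatAppGenerateEntity    (+ conversation_id)
    # ├── AdvancedChatAppGenerateEntity     (+ workflow app_config, conversation_id, query)
    # └── WorkflowUIBasedAppGenerateEntity  (+ workflow app_config)

+
+
+class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
+    """
+    Workflow UI Based Application Generate Entity.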
+ """ + # app config + app_config: WorkflowUIBasedAppConfig diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index ec316248a27afe..7d555328db9717 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,6 +1,6 @@ import logging -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity from core.helper import moderation from core.model_runtime.entities.message_entities import PromptMessage diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py index 359369ef59a9b5..926b0e128c815d 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -7,7 +7,8 @@ from pydantic import BaseModel from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \ + AgentChatAppGenerateEntity from core.app.entities.queue_entities import ( AnnotationReplyEvent, QueueAgentMessageEvent, @@ -39,7 +40,7 @@ from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought, MessageFile +from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) @@ -58,7 +59,11 @@ class GenerateTaskPipeline: GenerateTaskPipeline is a class that generate stream output and state management for Application. 
""" - def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def __init__(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -433,6 +438,7 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price + self._message.currency = usage.currency db.session.commit() @@ -440,7 +446,11 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, + is_first_message=self._application_generate_entity.app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.ADVANCED_CHAT + ] and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras ) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index bff9b9cf1fab4b..20feae8554f79d 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,7 +1,7 @@ import logging import random -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from extensions.ext_hosting_provider import hosting_configuration @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool: +def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config if (moderation_config and moderation_config.enabled is True and 'openai' in hosting_configuration.provider_map diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index cdd03b85f1ff3c..48b0d8ba021e03 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,7 +1,7 @@ from typing import Optional from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( @@ -28,7 +28,7 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: prompt_messages = [] model_mode = ModelMode.value_of(model_config.mode) @@ -62,7 +62,7 @@ def _get_completion_model_prompt_messages(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: + model_config: 
ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get completion model prompt messages. """ @@ -110,7 +110,7 @@ def _get_chat_model_prompt_messages(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get chat model prompt messages. """ @@ -199,7 +199,7 @@ def _set_histories_variable(self, memory: TokenBufferMemory, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, prompt_template: PromptTemplateParser, prompt_inputs: dict, - model_config: EasyUIBasedModelConfigEntity) -> dict: + model_config: ModelConfigWithCredentialsEntity) -> dict: if '#histories#' in prompt_template.variable_keys: if memory: inputs = {'#histories#': '', **prompt_inputs} diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 7fe8128a492d63..02e91d91128629 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,6 +1,6 @@ from typing import Optional, cast -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey @@ -10,14 +10,14 @@ class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], - model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int: + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index faf1f888e2b649..ca0efb200c15b1 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -4,7 +4,7 @@ from typing import Optional from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( @@ -52,7 +52,7 @@ def get_prompt(self, files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) -> \ + model_config: ModelConfigWithCredentialsEntity) -> \ tuple[list[PromptMessage], Optional[list[str]]]: model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.CHAT: @@ -81,7 +81,7 @@ def get_prompt(self, return prompt_messages, stops def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: 
EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, inputs: dict, query: Optional[str] = None, @@ -162,7 +162,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, context: Optional[str], files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) \ + model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] @@ -200,7 +200,7 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, context: Optional[str], files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: EasyUIBasedModelConfigEntity) \ + model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py index 9b115bc696c009..f2c5d4ca33042b 100644 --- a/api/core/rag/retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -5,14 +5,14 @@ from langchain.schema import Generation, LLMResult from langchain.schema.language_model import BaseLanguageModel -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.rag.retrieval.agent.fake_llm import FakeLLM class LLMChain(LCLLMChain): - model_config: EasyUIBasedModelConfigEntity + model_config: ModelConfigWithCredentialsEntity """The language model instance to use.""" llm: BaseLanguageModel = FakeLLM(response="") parameters: dict[str, Any] = {} diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 84e2b0228fd7ff..be24731d46a394 100644 --- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -10,7 +10,7 @@ from langchain.tools import BaseTool from pydantic import root_validator -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool @@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ An Multi Dataset Retrieve Agent driven by Router. 
""" - model_config: EasyUIBasedModelConfigEntity + model_config: ModelConfigWithCredentialsEntity class Config: """Configuration for this pydantic object.""" @@ -156,7 +156,7 @@ async def aplan( @classmethod def from_llm_and_tools( cls, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index 700bf0c293c8f1..7035ec8e2f7834 100644 --- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py @@ -12,7 +12,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). @@ -206,7 +206,7 @@ def _construct_scratchpad( @classmethod def from_llm_and_tools( cls, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py index 749e603c5ce52e..cb475bcffb7910 100644 --- a/api/core/rag/retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -7,7 +7,7 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Extra -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.message_entities import prompt_messages_to_lc_messages from core.helper import moderation @@ -22,9 +22,9 @@ class AgentConfiguration(BaseModel): strategy: PlanningStrategy - model_config: EasyUIBasedModelConfigEntity + model_config: ModelConfigWithCredentialsEntity tools: list[BaseTool] - summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None + summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None memory: Optional[TokenBufferMemory] = None callbacks: Callbacks = None max_iterations: int = 6 diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 37581f1e92fee8..395f2eb165e731 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -3,7 +3,7 @@ from langchain.tools import BaseTool from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.memory.token_buffer_memory 
import TokenBufferMemory @@ -18,7 +18,7 @@ class DatasetRetrieval: def retrieve(self, tenant_id: str, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, config: DatasetEntity, query: str, invoke_from: InvokeFrom, diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 49eea603dc7758..77d1ab0822d1dd 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -1,4 +1,4 @@ -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db @@ -8,7 +8,7 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') model_config = application_generate_entity.model_config provider_model_bundle = model_config.provider_model_bundle diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index d49e560a6705bd..eca773f3b31862 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,7 +9,7 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 453194feb1b7aa..4e3c4e19f649c6 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -1,180 +1,71 @@ -import json from collections.abc import Generator from typing import Any, Union -from sqlalchemy import and_ - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_manager import EasyUIBasedAppManager +from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator +from core.app.apps.chat.app_generator import ChatAppGenerator +from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.message_file_parser import MessageFileParser -from extensions.ext_database import db -from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message -from services.app_model_config_service import AppModelConfigService -from services.errors.app import MoreLikeThisDisabledError -from services.errors.app_model_config import 
AppModelConfigBrokenError -from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import MessageNotExistsError +from models.model import Account, App, AppMode, EndUser class CompletionService: @classmethod def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, - invoke_from: InvokeFrom, streaming: bool = True, - is_model_config_override: bool = False) -> Union[dict, Generator]: - # is streaming mode - inputs = args['inputs'] - query = args['query'] - files = args['files'] if 'files' in args and args['files'] else [] - auto_generate_name = args['auto_generate_name'] \ - if 'auto_generate_name' in args else True - - if app_model.mode != AppMode.COMPLETION.value: - if not query: - raise ValueError('query is required') - - if query: - if not isinstance(query, str): - raise ValueError('query must be a string') - - query = query.replace('\x00', '') - - conversation_id = args['conversation_id'] if 'conversation_id' in args else None - - conversation = None - app_model_config_dict = None - if conversation_id: - conversation_filter = [ - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id, - Conversation.status == 'normal' - ] - - if isinstance(user, Account): - conversation_filter.append(Conversation.from_account_id == user.id) - else: - conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) - - conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() - - if not conversation: - raise ConversationNotExistsError() - - if conversation.status != 'normal': - raise ConversationCompletedError() - - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() - - if not app_model_config: - raise AppModelConfigBrokenError() - else: - if app_model.app_model_config_id is None: - raise AppModelConfigBrokenError() - - app_model_config = app_model.app_model_config - - if not app_model_config: - raise AppModelConfigBrokenError() - - if is_model_config_override: - if not isinstance(user, Account): - raise Exception("Only account can override model config") - - # validate config - app_model_config_dict = AppModelConfigService.validate_configuration( - tenant_id=app_model.tenant_id, - config=args['model_config'], - app_mode=AppMode.value_of(app_model.mode) - ) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict()) - if file_upload_entity: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_upload_entity, - user + invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]: + """ + App Completion + :param app_model: app model + :param user: user + :param args: args + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + if app_model.mode == AppMode.COMPLETION.value: + return CompletionAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) + elif app_model.mode == AppMode.CHAT.value: + return ChatAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) + elif app_model.mode == AppMode.AGENT_CHAT.value: + return AgentChatAppGenerator().generate( + 
app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming ) else: - file_objs = [] - - application_manager = EasyUIBasedAppManager() - return application_manager.generate( - app_model=app_model, - app_model_config=app_model_config, - app_model_config_dict=app_model_config_dict, - user=user, - invoke_from=invoke_from, - inputs=inputs, - query=query, - files=file_objs, - conversation=conversation, - stream=streaming, - extras={ - "auto_generate_conversation_name": auto_generate_name - } - ) + raise ValueError('Invalid app mode') @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ -> Union[dict, Generator]: - if not user: - raise ValueError('user cannot be None') - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() - - if not message: - raise MessageNotExistsError() - - current_app_model_config = app_model.app_model_config - more_like_this = current_app_model_config.more_like_this_dict - - if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: - raise MoreLikeThisDisabledError() - - app_model_config = message.app_model_config - model_dict = app_model_config.model_dict - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - app_model_config.model = json.dumps(model_dict) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict()) - if file_upload_entity: - file_objs = message_file_parser.transform_message_files( - message.files, file_upload_entity - ) - else: - file_objs = [] - - application_manager = EasyUIBasedAppManager() - return application_manager.generate( + """ + Generate more like this + :param app_model: app model + :param user: user + :param message_id: message id + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + return CompletionAppGenerator().generate_more_like_this( app_model=app_model, - app_model_config=current_app_model_config, - app_model_config_dict=app_model_config.to_dict(), + message_id=message_id, user=user, invoke_from=invoke_from, - inputs=message.inputs, - query=message.query, - files=file_objs, - conversation=None, - stream=streaming, - extras={ - "auto_generate_conversation_name": False - } + stream=streaming ) - diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index b3061cc2552d67..9d377cc466337b 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -8,9 +8,11 @@ FileUploadEntity, ModelConfigEntity, PromptTemplateEntity, - VariableEntity, + VariableEntity, EasyUIBasedAppConfig, ) -from core.app.app_manager import EasyUIBasedAppManager +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.helper import encrypter 
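
The rewritten CompletionService above is now a thin dispatcher: it inspects app_model.mode and delegates to the matching app generator, so request parsing, conversation lookup, and file validation move into the generators themselves. A minimal sketch of the calling convention, assuming an existing App record and an authenticated Account (the args keys mirror the ones the generators read):

    from core.app.entities.app_invoke_entities import InvokeFrom
    from services.completion_service import CompletionService

    # Hypothetical call site; app_model and current_user come from the caller's context.
    response = CompletionService.completion(
        app_model=app_model,
        user=current_user,
        args={
            'inputs': {},
            'query': 'Hello',
            'files': [],
            'conversation_id': None,  # None starts a new conversation in chat modes
            'auto_generate_name': True,
        },
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,  # returns a Generator when True, a dict otherwise
    )
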
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -87,8 +89,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, new_app_mode = self._get_new_app_mode(app_model) # convert app model config - application_manager = EasyUIBasedAppManager() - app_config = application_manager.convert_to_app_config( + app_config = self._convert_to_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -190,6 +191,30 @@ def convert_app_model_config_to_workflow(self, app_model: App, return workflow + def _convert_to_app_config(self, app_model: App, + app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_model.mode = AppMode.AGENT_CHAT.value + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + elif app_mode == AppMode.CHAT: + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + elif app_mode == AppMode.COMPLETION: + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + else: + raise ValueError("Invalid app mode") + + return app_config + def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: """ Convert to Start Node @@ -566,6 +591,6 @@ def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) :return: """ return db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 70f6070c6bbc3e..be9fe8d004fa32 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform @@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p(): def test__get_chat_model_prompt_messages(): - model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' @@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): - model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) + model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-3.5-turbo-instruct' From 602bc67495d62334fc7796a0a6eaeacd19e33770 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 3 Mar 2024 04:18:51 +0800 Subject: [PATCH 057/160] lint fix --- api/core/agent/base_agent_runner.py | 3 ++- api/core/app/apps/agent_chat/app_generator.py | 9 +++++---- 
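
The updated prompt-transform tests above rely on MagicMock(spec=ModelConfigWithCredentialsEntity) rather than a bare MagicMock. A short sketch of why the spec matters; the provider/model attributes are taken from those tests, and the failing access is hypothetical:

    from unittest.mock import MagicMock

    from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity

    # spec= constrains reads to the real entity's interface, so a renamed or
    # misspelled attribute fails loudly instead of returning a child mock.
    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-4'

    # model_config_mock.nonexistent_field  # would raise AttributeError under spec=
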
api/core/app/apps/agent_chat/app_runner.py | 3 +-- api/core/app/apps/base_app_generator.py | 2 +- api/core/app/apps/base_app_runner.py | 4 +++- api/core/app/apps/chat/app_generator.py | 9 +++++---- .../app/apps/completion/app_config_manager.py | 2 +- api/core/app/apps/completion/app_generator.py | 10 +++++----- .../app/apps/message_based_app_generator.py | 18 ++++++++++++------ api/core/app/entities/app_invoke_entities.py | 2 +- .../hosting_moderation/hosting_moderation.py | 2 +- api/core/app/generate_task_pipeline.py | 10 +++++++--- api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/services/workflow/workflow_converter.py | 3 ++- 14 files changed, 47 insertions(+), 32 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ef530b9122237b..236a5d9cf77244 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -10,8 +10,9 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + InvokeFrom, ModelConfigWithCredentialsEntity, - InvokeFrom, AgentChatAppGenerateEntity, ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 1ab456d8223532..d5dbdf0dd2cb05 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,18 +1,19 @@ import logging import threading import uuid -from typing import Union, Any, Generator +from collections.abc import Generator +from typing import Any, Union -from flask import current_app, Flask +from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 6bae5e1648188e..27a473fb17e3d0 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -7,8 +7,7 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \ - AgentChatAppGenerateEntity +from core.app.entities.app_invoke_entities import 
AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 65764021aaf3ac..750c6dae1036b5 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,4 +1,4 @@ -from core.app.app_config.entities import VariableEntity, AppConfig +from core.app.app_config.entities import AppConfig, VariableEntity class BaseAppGenerator: diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index ee70f161a27181..8de71d4bfb69f9 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,8 +5,10 @@ from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, + EasyUIBasedAppGenerateEntity, + InvokeFrom, ModelConfigWithCredentialsEntity, - InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 712822f3a5664f..978ac9656b64a5 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,18 +1,19 @@ import logging import threading import uuid -from typing import Union, Any, Generator +from collections.abc import Generator +from typing import Any, Union -from flask import current_app, Flask +from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 77a14430373bff..a82e68a337ab4c 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -10,7 +10,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, 
AppMode, AppModelConfig class CompletionAppConfig(EasyUIBasedAppConfig): diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index d258a3bd9da7e8..9355bae12381f8 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,19 +1,19 @@ -import json import logging import threading import uuid -from typing import Union, Any, Generator +from collections.abc import Generator +from typing import Any, Union -from flask import current_app, Flask +from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, CompletionAppGenerateEntity +from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 783c6c6ee52063..2fb609e615f90e 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,21 +1,27 @@ import json import logging -from typing import Union, Generator, Optional +from collections.abc import Generator +from typing import Optional, Union from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom -from core.app.app_queue_manager import ConversationTaskStoppedException, AppQueueManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity, AppGenerateEntity, \ - CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + AppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + InvokeFrom, +) from core.app.generate_task_pipeline import GenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account -from models.model import Conversation, Message, AppMode, MessageFile, App, EndUser, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError -from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError +from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError logger = logging.getLogger(__name__) diff --git a/api/core/app/entities/app_invoke_entities.py 
b/api/core/app/entities/app_invoke_entities.py index 9097345674f3bf..1c4f32b8f28da1 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig +from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file.file_obj import FileObj from core.model_runtime.entities.model_entities import AIModelEntity diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index 7d555328db9717..ec316248a27afe 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,6 +1,6 @@ import logging -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation from core.model_runtime.entities.message_entities import PromptMessage diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py index 926b0e128c815d..60dfc5cdad7e6a 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -7,8 +7,12 @@ from pydantic import BaseModel from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \ - AgentChatAppGenerateEntity +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + InvokeFrom, +) from core.app.entities.queue_entities import ( AnnotationReplyEvent, QueueAgentMessageEvent, @@ -40,7 +44,7 @@ from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode +from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 395f2eb165e731..ee728423262fce 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -3,7 +3,7 @@ from langchain.tools import BaseTool from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 9d377cc466337b..527c6543812a0a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -4,11 +4,12 @@ from core.app.app_config.entities import ( 
DatasetEntity, DatasetRetrieveConfigEntity, + EasyUIBasedAppConfig, ExternalDataVariableEntity, FileUploadEntity, ModelConfigEntity, PromptTemplateEntity, - VariableEntity, EasyUIBasedAppConfig, + VariableEntity, ) from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager From be709d4b844f870cb0457f0e58e5e74009405f9b Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:04:40 +0800 Subject: [PATCH 058/160] add AdvancedChatAppGenerateTaskPipeline --- api/core/app/app_queue_manager.py | 67 +++++- .../apps/advanced_chat/app_config_manager.py | 6 +- .../app/apps/advanced_chat/app_generator.py | 218 ++++++++++++++++++ api/core/app/apps/advanced_chat/app_runner.py | 103 +++++++++ api/core/app/apps/base_app_runner.py | 4 +- .../app/apps/message_based_app_generator.py | 38 +-- api/core/app/entities/queue_entities.py | 74 ++++-- api/core/workflow/workflow_engine_manager.py | 38 +++ .../deduct_quota_when_messaeg_created.py | 7 +- ...rsation_name_when_first_message_created.py | 3 +- ...vider_last_used_at_when_messaeg_created.py | 7 +- api/models/model.py | 6 +- api/models/workflow.py | 41 ++++ api/services/workflow_service.py | 19 +- 14 files changed, 570 insertions(+), 61 deletions(-) create mode 100644 api/core/app/apps/advanced_chat/app_generator.py create mode 100644 api/core/app/apps/advanced_chat/app_runner.py diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py index 4bd491269cebb7..5655c8d979b4f2 100644 --- a/api/core/app/app_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -8,19 +8,24 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( - AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, QueueErrorEvent, + QueueLLMChunkEvent, QueueMessage, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, + QueueNodeFinishedEvent, + QueueNodeStartedEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client @@ -97,18 +102,30 @@ def stop_listen(self) -> None: """ self._q.put(None) - def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: + def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: """ - Publish chunk message to channel + Publish llm chunk to channel - :param chunk: chunk + :param chunk: llm chunk :param pub_from: publish from :return: """ - self.publish(QueueMessageEvent( + self.publish(QueueLLMChunkEvent( chunk=chunk ), pub_from) + def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None: + """ + Publish text chunk to channel + + :param text: text + :param pub_from: publish from + :return: + """ + self.publish(QueueTextChunkEvent( + text=text + ), pub_from) + def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: """ Publish agent chunk message to channel @@ -146,7 +163,7 @@ def publish_annotation_reply(self, message_annotation_id: str, pub_from: Publish :param pub_from: publish from :return: """ - self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from) + 
self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from) def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None: """ @@ -158,6 +175,42 @@ def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> N self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from) self.stop_listen() + def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None: + """ + Publish workflow started + :param workflow_run_id: workflow run id + :param pub_from: publish from + :return: + """ + self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from) + + def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None: + """ + Publish workflow finished + :param workflow_run_id: workflow run id + :param pub_from: publish from + :return: + """ + self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from) + + def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: + """ + Publish node started + :param workflow_node_execution_id: workflow node execution id + :param pub_from: publish from + :return: + """ + self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) + + def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: + """ + Publish node finished + :param workflow_node_execution_id: workflow node execution id + :param pub_from: publish from + :return: + """ + self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) + def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None: """ Publish agent thought diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index 72ba4c33d4e1a4..3ac26ebe80c4ec 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ -from typing import Optional from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager @@ -12,7 +11,7 @@ ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import App, AppMode, Conversation +from models.model import App, AppMode from models.workflow import Workflow @@ -26,8 +25,7 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config(cls, app_model: App, - workflow: Workflow, - conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig: + workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_config = AdvancedChatAppConfig( diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py new file mode 100644 index 00000000000000..ca2f4005479915 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -0,0 +1,218 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from 
core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get workflow + workflow_engine_manager = WorkflowEngineManager() + if invoke_from == InvokeFrom.DEBUGGER: + workflow = workflow_engine_manager.get_draft_workflow(app_model=app_model) + else: + workflow = workflow_engine_manager.get_published_workflow(app_model=app_model) + + if not workflow: + raise ValueError('Workflow not initialized') + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + 
user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
+ :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py new file mode 100644 index 00000000000000..0d701ae2240d96 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -0,0 +1,103 @@ +import logging +from typing import cast + +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, +) +from core.moderation.base import ModerationException +from extensions.ext_database import db +from models.model import App, Conversation, Message + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppRunner(AppRunner): + """ + AdvancedChat Application Runner + """ + + def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(AdvancedChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + ) + except ModerationException as e: + # TODO + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + + # TODO + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # check hosting 
moderation + # TODO + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # todo RUN WORKFLOW \ No newline at end of file diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 8de71d4bfb69f9..4e099c9ae1529a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -187,7 +187,7 @@ def direct_output(self, queue_manager: AppQueueManager, if stream: index = 0 for token in text: - queue_manager.publish_chunk_message(LLMResultChunk( + queue_manager.publish_llm_chunk(LLMResultChunk( model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( @@ -261,7 +261,7 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage = None for result in invoke_result: if not agent: - queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish_llm_chunk(result, PublishFrom.APPLICATION_MANAGER) else: queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 2fb609e615f90e..dab72bd6d6cf7c 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,14 +8,15 @@ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, AppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom, ) -from core.app.generate_task_pipeline import GenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account @@ -31,7 +32,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response(self, application_generate_entity: Union[ ChatAppGenerateEntity, CompletionAppGenerateEntity, - AgentChatAppGenerateEntity + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity ], queue_manager: AppQueueManager, conversation: Conversation, @@ -47,7 +49,7 @@ def _handle_response(self, application_generate_entity: Union[ :return: """ # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( + generate_task_pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, @@ -114,7 +116,8 @@ def _init_generate_records(self, application_generate_entity: Union[ ChatAppGenerateEntity, CompletionAppGenerateEntity, - AgentChatAppGenerateEntity + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity ], conversation: Optional[Conversation] = None) \ -> tuple[Conversation, Message]: @@ -135,10 +138,19 @@ def _init_generate_records(self, from_source = 'console' account_id = application_generate_entity.user_id - override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, 
AppMode.COMPLETION]: - override_model_configs = app_config.app_model_config_dict + if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): + app_model_config_id = None + override_model_configs = None + model_provider = None + model_id = None + else: + app_model_config_id = app_config.app_model_config_id + model_provider = application_generate_entity.model_config.provider + model_id = application_generate_entity.model_config.model + override_model_configs = None + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ + and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + override_model_configs = app_config.app_model_config_dict # get conversation introduction introduction = self._get_conversation_introduction(application_generate_entity) @@ -146,9 +158,9 @@ def _init_generate_records(self, if not conversation: conversation = Conversation( app_id=app_config.app_id, - app_model_config_id=app_config.app_model_config_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, + app_model_config_id=app_model_config_id, + model_provider=model_provider, + model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, name='New conversation', @@ -167,8 +179,8 @@ def _init_generate_records(self, message = Message( app_id=app_config.app_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, + model_provider=model_provider, + model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, conversation_id=conversation.id, inputs=application_generate_entity.inputs, diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index c1f8fb7e8964a9..25bdd7d9e39c4e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -10,14 +10,19 @@ class QueueEvent(Enum): """ QueueEvent enum """ - MESSAGE = "message" + LLM_CHUNK = "llm_chunk" + TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" - MESSAGE_REPLACE = "message-replace" - MESSAGE_END = "message-end" - RETRIEVER_RESOURCES = "retriever-resources" - ANNOTATION_REPLY = "annotation-reply" - AGENT_THOUGHT = "agent-thought" - MESSAGE_FILE = "message-file" + MESSAGE_REPLACE = "message_replace" + MESSAGE_END = "message_end" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_FINISHED = "workflow_finished" + NODE_STARTED = "node_started" + NODE_FINISHED = "node_finished" + RETRIEVER_RESOURCES = "retriever_resources" + ANNOTATION_REPLY = "annotation_reply" + AGENT_THOUGHT = "agent_thought" + MESSAGE_FILE = "message_file" ERROR = "error" PING = "ping" STOP = "stop" @@ -30,13 +35,22 @@ class AppQueueEvent(BaseModel): event: QueueEvent -class QueueMessageEvent(AppQueueEvent): +class QueueLLMChunkEvent(AppQueueEvent): """ - QueueMessageEvent entity + QueueLLMChunkEvent entity """ - event = QueueEvent.MESSAGE + event = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + +class QueueTextChunkEvent(AppQueueEvent): + """ + QueueTextChunkEvent entity + """ + event = QueueEvent.TEXT_CHUNK + chunk_text: str + + class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity @@ -61,9 +75,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): retriever_resources: list[dict] -class AnnotationReplyEvent(AppQueueEvent): +class 
QueueAnnotationReplyEvent(AppQueueEvent): """ - AnnotationReplyEvent entity + QueueAnnotationReplyEvent entity """ event = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -76,6 +90,38 @@ class QueueMessageEndEvent(AppQueueEvent): event = QueueEvent.MESSAGE_END llm_result: LLMResult + +class QueueWorkflowStartedEvent(AppQueueEvent): + """ + QueueWorkflowStartedEvent entity + """ + event = QueueEvent.WORKFLOW_STARTED + workflow_run_id: str + + +class QueueWorkflowFinishedEvent(AppQueueEvent): + """ + QueueWorkflowFinishedEvent entity + """ + event = QueueEvent.WORKFLOW_FINISHED + workflow_run_id: str + + +class QueueNodeStartedEvent(AppQueueEvent): + """ + QueueNodeStartedEvent entity + """ + event = QueueEvent.NODE_STARTED + workflow_node_execution_id: str + + +class QueueNodeFinishedEvent(AppQueueEvent): + """ + QueueNodeFinishedEvent entity + """ + event = QueueEvent.NODE_FINISHED + workflow_node_execution_id: str + class QueueAgentThoughtEvent(AppQueueEvent): """ @@ -84,13 +130,15 @@ class QueueAgentThoughtEvent(AppQueueEvent): event = QueueEvent.AGENT_THOUGHT agent_thought_id: str + class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ event = QueueEvent.MESSAGE_FILE message_file_id: str - + + class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index e69de29bb2d1d6..f7955a87e8cf1b 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -0,0 +1,38 @@ +from typing import Optional + +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowEngineManager: + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() + + # return draft workflow + return workflow + + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get published workflow + """ + if not app_model.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id + ).first() + + # return published workflow + return workflow diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 77d1ab0822d1dd..53cbb2ecdc96ce 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -1,4 +1,4 @@ -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db @@ -8,7 +8,10 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | 
AgentChatAppGenerateEntity): + return model_config = application_generate_entity.model_config provider_model_bundle = model_config.provider_model_bundle diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index f5f3ba2540d7aa..31535bf4ef68fb 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,6 +1,7 @@ from core.llm_generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db +from models.model import AppMode @message_was_created.connect @@ -15,7 +16,7 @@ def handle(sender, **kwargs): auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) if auto_generate_conversation_name and is_first_message: - if conversation.mode == 'chat': + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: return diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index eca773f3b31862..ae983cc5d1a537 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,7 +9,10 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, diff --git a/api/models/model.py b/api/models/model.py index f8f9a0a3cddc6e..c579c3dee83399 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -451,10 +451,10 @@ class Conversation(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - app_model_config_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) + app_model_config_id = db.Column(UUID, nullable=True) + model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - model_id = db.Column(db.String(255), nullable=False) + model_id = db.Column(db.String(255), nullable=True) mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) diff --git a/api/models/workflow.py b/api/models/workflow.py index f9c906b85c864c..2540d334025c4e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -272,6 +272,10 @@ def created_by_end_user(self): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def outputs_dict(self): + return self.outputs if not self.outputs 
else json.loads(self.outputs) + class WorkflowNodeExecutionTriggeredFrom(Enum): """ @@ -294,6 +298,28 @@ def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': raise ValueError(f'invalid workflow node execution triggered from value {value}') +class WorkflowNodeExecutionStatus(Enum): + """ + Workflow Node Execution Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow node execution status value {value}') + + class WorkflowNodeExecution(db.Model): """ Workflow Node Execution @@ -387,6 +413,21 @@ def created_by_end_user(self): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def inputs_dict(self): + return self.inputs if not self.inputs else json.loads(self.inputs) + + @property + def outputs_dict(self): + return self.outputs if not self.outputs else json.loads(self.outputs) + + @property + def process_data_dict(self): + return self.process_data if not self.process_data else json.loads(self.process_data) + + @property + def execution_metadata_dict(self): + return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) class WorkflowAppLog(db.Model): """ diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c9efd056ff2b98..13ea67d343b6b3 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -4,6 +4,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -21,15 +22,10 @@ def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow """ - # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow_engine_manager = WorkflowEngineManager() # return draft workflow - return workflow + return workflow_engine_manager.get_draft_workflow(app_model=app_model) def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ @@ -38,15 +34,10 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow_engine_manager = WorkflowEngineManager() # return published workflow - return workflow + return workflow_engine_manager.get_published_workflow(app_model=app_model) def sync_draft_workflow(self, app_model: App, graph: dict, From e9004a06a563b92a45df16dbadd99a3855378cfc Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:04:46 +0800 Subject: [PATCH 059/160] lint fix --- .../advanced_chat/generate_task_pipeline.py | 563 ++++++++++++++++++ .../easy_ui_based_generate_task_pipeline.py} | 43 +- 2 files changed, 585 insertions(+), 21 deletions(-) create 
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
new file mode 100644
index 00000000000000..d443435fc18beb
--- /dev/null
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -0,0 +1,563 @@
+import json
+import logging
+import time
+from collections.abc import Generator
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
+from core.app.entities.queue_entities import (
+    QueueAnnotationReplyEvent,
+    QueueErrorEvent,
+    QueueMessageFileEvent,
+    QueueMessageReplaceEvent,
+    QueueNodeFinishedEvent,
+    QueueNodeStartedEvent,
+    QueuePingEvent,
+    QueueRetrieverResourcesEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFinishedEvent,
+    QueueWorkflowStartedEvent,
+)
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.moderation.output_moderation import ModerationRule, OutputModeration
+from core.tools.tool_file_manager import ToolFileManager
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.model import Conversation, Message, MessageFile
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
+from services.annotation_service import AppAnnotationService
+
+logger = logging.getLogger(__name__)
+
+
+class TaskState(BaseModel):
+    """
+    TaskState entity
+    """
+    answer: str = ""
+    metadata: dict = {}
+
+
+class AdvancedChatAppGenerateTaskPipeline:
+    """
+    AdvancedChatAppGenerateTaskPipeline is a class that generates stream output and manages state for the application.
+    """
+
+    def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
+                 queue_manager: AppQueueManager,
+                 conversation: Conversation,
+                 message: Message) -> None:
+        """
+        Initialize AdvancedChatAppGenerateTaskPipeline.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        """
+        self._application_generate_entity = application_generate_entity
+        self._queue_manager = queue_manager
+        self._conversation = conversation
+        self._message = message
+        # TaskState here only tracks answer/metadata; it has no usage field
+        self._task_state = TaskState()
+        self._start_at = time.perf_counter()
+        self._output_moderation_handler = self._init_output_moderation()
+
+    def process(self, stream: bool) -> Union[dict, Generator]:
+        """
+        Process generate task pipeline.
+        :return:
+        """
+        if stream:
+            return self._process_stream_response()
+        else:
+            return self._process_blocking_response()
+
+    def _process_blocking_response(self) -> dict:
+        """
+        Process blocking response.
+        :return:
+        """
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueueErrorEvent):
+                raise self._handle_error(event)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, QueueAnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.answer = annotation.content
+            elif isinstance(event, QueueNodeFinishedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
+                    if workflow_node_execution.node_type == 'llm':  # todo use enum
+                        outputs = workflow_node_execution.outputs_dict
+                        usage_dict = outputs.get('usage', {})
+                        self._task_state.metadata['usage'] = usage_dict
+            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
+                if isinstance(event, QueueWorkflowFinishedEvent):
+                    workflow_run = self._get_workflow_run(event.workflow_run_id)
+                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
+                        # outputs is a JSON text column; parse it via outputs_dict before reading fields
+                        outputs = workflow_run.outputs_dict
+                        self._task_state.answer = outputs.get('text', '') if outputs else ''
+                    else:
+                        raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
+
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.answer,
+                        public_event=False
+                    )
+
+                # Save message
+                self._save_message()
+
+                response = {
+                    'event': 'message',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                    'mode': self._conversation.mode,
+                    'answer': self._task_state.answer,
+                    'metadata': {},
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._get_response_metadata()
+
+                return response
+            else:
+                continue
+
+    def _process_stream_response(self) -> Generator:
+        """
+        Process stream response.
+        :return:
+        """
+        for message in self._queue_manager.listen():
+            event = message.event
+
+            if isinstance(event, QueueErrorEvent):
+                data = self._error_to_stream_response_data(self._handle_error(event))
+                yield self._yield_response(data)
+                break
+            elif isinstance(event, QueueWorkflowStartedEvent):
+                workflow_run = self._get_workflow_run(event.workflow_run_id)
+                response = {
+                    'event': 'workflow_started',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': event.workflow_run_id,
+                    'data': {
+                        'id': workflow_run.id,
+                        'workflow_id': workflow_run.workflow_id,
+                        'created_at': int(workflow_run.created_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueNodeStartedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                response = {
+                    'event': 'node_started',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': workflow_node_execution.workflow_run_id,
+                    'data': {
+                        'id': workflow_node_execution.id,
+                        'node_id': workflow_node_execution.node_id,
+                        'index': workflow_node_execution.index,
+                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
+                        'inputs': workflow_node_execution.inputs_dict,
+                        'created_at': int(workflow_node_execution.created_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueNodeFinishedEvent):
+                workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
+                if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
+                    if workflow_node_execution.node_type == 'llm':  # todo use enum
+                        outputs = workflow_node_execution.outputs_dict
+                        usage_dict = outputs.get('usage', {})
+                        self._task_state.metadata['usage'] = usage_dict
+
+                response = {
+                    'event': 'node_finished',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': workflow_node_execution.workflow_run_id,
+                    'data': {
+                        'id': workflow_node_execution.id,
+                        'node_id': workflow_node_execution.node_id,
+                        'index': workflow_node_execution.index,
+                        'predecessor_node_id': workflow_node_execution.predecessor_node_id,
+                        'inputs': workflow_node_execution.inputs_dict,
+                        'process_data': workflow_node_execution.process_data_dict,
+                        'outputs': workflow_node_execution.outputs_dict,
+                        'status': workflow_node_execution.status,
+                        'error': workflow_node_execution.error,
+                        'elapsed_time': workflow_node_execution.elapsed_time,
+                        'execution_metadata': workflow_node_execution.execution_metadata_dict,
+                        'created_at': int(workflow_node_execution.created_at.timestamp()),
+                        'finished_at': int(workflow_node_execution.finished_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
+                if isinstance(event, QueueWorkflowFinishedEvent):
+                    workflow_run = self._get_workflow_run(event.workflow_run_id)
+                    if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
+                        # outputs is a JSON text column; parse it via outputs_dict before reading fields
+                        outputs = workflow_run.outputs_dict
+                        self._task_state.answer = outputs.get('text', '') if outputs else ''
+                    else:
+                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                        data = self._error_to_stream_response_data(self._handle_error(err_event))
+                        yield self._yield_response(data)
+                        break
+
+                workflow_run_response = {
+                    'event': 'workflow_finished',
+                    'task_id': self._application_generate_entity.task_id,
+                    'workflow_run_id': workflow_run.id,  # QueueStopEvent carries no workflow_run_id
+                    'data': {
+                        'id': workflow_run.id,
+                        'workflow_id': workflow_run.workflow_id,
+                        'status':
workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_price': workflow_run.total_price, + 'currency': workflow_run.currency, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + yield self._yield_response(workflow_run_response) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + self._output_moderation_handler = None + + replace_response = { + 'event': 'message_replace', + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'answer': self._task_state.answer, + 'created_at': int(self._message.created_at.timestamp()) + } + + yield self._yield_response(replace_response) + + # Save message + self._save_message() + + response = { + 'event': 'message_end', + 'task_id': self._application_generate_entity.task_id, + 'id': self._message.id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + } + + if self._task_state.metadata: + response['metadata'] = self._get_response_metadata() + + yield self._yield_response(response) + elif isinstance(event, QueueRetrieverResourcesEvent): + self._task_state.metadata['retriever_resources'] = event.retriever_resources + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + if annotation: + account = annotation.account + self._task_state.metadata['annotation_reply'] = { + 'id': annotation.id, + 'account': { + 'id': annotation.account_id, + 'name': account.name if account else 'Dify user' + } + } + + self._task_state.answer = annotation.content + elif isinstance(event, QueueMessageFileEvent): + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter(MessageFile.id == event.message_file_id) + .first() + ) + # get extension + if '.' 
in (message_file.url if message_file else ''):
+                    extension = f'.{message_file.url.split(".")[-1]}'
+                    if len(extension) > 10:
+                        extension = '.bin'
+                else:
+                    extension = '.bin'
+
+                if message_file:
+                    # add sign url only when the queried file actually exists
+                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)
+
+                    response = {
+                        'event': 'message_file',
+                        'conversation_id': self._conversation.id,
+                        'id': message_file.id,
+                        'type': message_file.type,
+                        'belongs_to': message_file.belongs_to or 'user',
+                        'url': url
+                    }
+
+                    yield self._yield_response(response)
+            elif isinstance(event, QueueTextChunkEvent):
+                delta_text = event.chunk_text
+                if delta_text is None:
+                    continue
+
+                if self._output_moderation_handler:
+                    if self._output_moderation_handler.should_direct_output():
+                        # stop subscribing to new tokens once output moderation chooses direct output
+                        self._task_state.answer = self._output_moderation_handler.get_final_output()
+                        self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
+                        self._queue_manager.publish(
+                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                            PublishFrom.TASK_PIPELINE
+                        )
+                        continue
+                    else:
+                        self._output_moderation_handler.append_new_token(delta_text)
+
+                self._task_state.answer += delta_text
+                response = self._handle_chunk(delta_text)
+                yield self._yield_response(response)
+            elif isinstance(event, QueueMessageReplaceEvent):
+                response = {
+                    'event': 'message_replace',
+                    'task_id': self._application_generate_entity.task_id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                    'answer': event.text,
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueuePingEvent):
+                yield "event: ping\n\n"
+            else:
+                continue
+
+    def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+        """
+        Get workflow run.
+        :param workflow_run_id: workflow run id
+        :return:
+        """
+        return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
+
+    def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
+        """
+        Get workflow node execution.
+        :param workflow_node_execution_id: workflow node execution id
+        :return:
+        """
+        return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()
+
+    def _save_message(self) -> None:
+        """
+        Save message.
+ :return: + """ + self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + + self._message.answer = self._task_state.answer + self._message.provider_response_latency = time.perf_counter() - self._start_at + + if self._task_state.metadata and self._task_state.metadata.get('usage'): + usage = LLMUsage(**self._task_state.metadata['usage']) + + self._message.message_tokens = usage.prompt_tokens + self._message.message_unit_price = usage.prompt_unit_price + self._message.message_price_unit = usage.prompt_price_unit + self._message.answer_tokens = usage.completion_tokens + self._message.answer_unit_price = usage.completion_unit_price + self._message.answer_price_unit = usage.completion_price_unit + self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.total_price = usage.total_price + self._message.currency = usage.currency + + db.session.commit() + + message_was_created.send( + self._message, + application_generate_entity=self._application_generate_entity, + conversation=self._conversation, + is_first_message=self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras + ) + + def _handle_chunk(self, text: str) -> dict: + """ + Handle completed event. + :param text: text + :return: + """ + response = { + 'event': 'message', + 'id': self._message.id, + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'answer': text, + 'created_at': int(self._message.created_at.timestamp()) + } + + return response + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response_data(self, e: Exception) -> dict: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return { + 'event': 'error', + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + **data + } + + def _get_response_metadata(self) -> dict: + """ + Get response metadata by invoke from. 
+ :return: + """ + metadata = {} + + # show_retrieve_source + if 'retriever_resources' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] + else: + metadata['retriever_resources'] = [] + for resource in self._task_state.metadata['retriever_resources']: + metadata['retriever_resources'].append({ + 'segment_id': resource['segment_id'], + 'position': resource['position'], + 'document_name': resource['document_name'], + 'score': resource['score'], + 'content': resource['content'], + }) + # show annotation reply + if 'annotation_reply' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] + + # show usage + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['usage'] = self._task_state.metadata['usage'] + + return metadata + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. + :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + on_message_replace_func=self._queue_manager.publish_message_replace + ) diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py similarity index 95% rename from api/core/app/generate_task_pipeline.py rename to api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 60dfc5cdad7e6a..80596668b8680f 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -14,12 +14,12 @@ InvokeFrom, ) from core.app.entities.queue_entities import ( - AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, QueueErrorEvent, + QueueLLMChunkEvent, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, @@ -40,6 +40,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created @@ -58,9 +59,9 @@ class TaskState(BaseModel): metadata: dict = {} -class GenerateTaskPipeline: +class EasyUIBasedGenerateTaskPipeline: """ - GenerateTaskPipeline is a class that generate stream output and state management for Application. + EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. 
""" def __init__(self, application_generate_entity: Union[ @@ -79,12 +80,13 @@ def __init__(self, application_generate_entity: Union[ :param message: message """ self._application_generate_entity = application_generate_entity + self._model_config = application_generate_entity.model_config self._queue_manager = queue_manager self._conversation = conversation self._message = message self._task_state = TaskState( llm_result=LLMResult( - model=self._application_generate_entity.model_config.model, + model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() @@ -119,7 +121,7 @@ def _process_blocking_response(self) -> dict: raise self._handle_error(event) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -136,7 +138,7 @@ def _process_blocking_response(self) -> dict: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -193,7 +195,7 @@ def _process_blocking_response(self) -> dict: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -219,7 +221,7 @@ def _process_stream_response(self) -> Generator: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -272,7 +274,7 @@ def _process_stream_response(self) -> Generator: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: replace_response['conversation_id'] = self._conversation.id yield self._yield_response(replace_response) @@ -287,7 +289,7 @@ def _process_stream_response(self) -> Generator: 'message_id': self._message.id, } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -296,7 +298,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -334,7 +336,7 @@ def _process_stream_response(self) -> Generator: 'message_files': agent_thought.files } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: 
response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -365,12 +367,12 @@ def _process_stream_response(self) -> Generator: 'url': url } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) - elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): + elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: @@ -383,7 +385,7 @@ def _process_stream_response(self) -> Generator: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_chunk_message(LLMResultChunk( + self._queue_manager.publish_llm_chunk(LLMResultChunk( model=self._task_state.llm_result.model, prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( @@ -411,7 +413,7 @@ def _process_stream_response(self) -> Generator: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -452,8 +454,7 @@ def _save_message(self, llm_result: LLMResult) -> None: conversation=self._conversation, is_first_message=self._application_generate_entity.app_config.app_mode in [ AppMode.AGENT_CHAT, - AppMode.CHAT, - AppMode.ADVANCED_CHAT + AppMode.CHAT ] and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras ) @@ -473,7 +474,7 @@ def _handle_chunk(self, text: str, agent: bool = False) -> dict: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id return response @@ -583,7 +584,7 @@ def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMess :return: """ prompts = [] - if self._application_generate_entity.model_config.mode == 'chat': + if self._model_config.mode == ModelMode.CHAT.value: for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' From d9b8a938c6a68ea4cdbdbcb9c01333e356eafe08 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:05:47 +0800 Subject: [PATCH 060/160] use enum instead --- api/core/app/apps/advanced_chat/generate_task_pipeline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index d443435fc18beb..2aa649afeab0ff 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -30,6 +30,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.NodeEntities import NodeType from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageFile @@ -111,7 +112,7 @@ def 
_process_blocking_response(self) -> dict: elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == 'llm': # todo use enum + if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict usage_dict = outputs.get('usage', {}) self._task_state.metadata['usage'] = usage_dict @@ -201,7 +202,7 @@ def _process_stream_response(self) -> Generator: elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == 'llm': # todo use enum + if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict usage_dict = outputs.get('usage', {}) self._task_state.metadata['usage'] = usage_dict From 75559bcbf90168ab4cf5f0b04881b0e4b01d6835 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:06:27 +0800 Subject: [PATCH 061/160] replace block type to node type --- api/core/workflow/entities/NodeEntities.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/NodeEntities.py index d72b000dfb876d..80471cc7024fb5 100644 --- a/api/core/workflow/entities/NodeEntities.py +++ b/api/core/workflow/entities/NodeEntities.py @@ -19,14 +19,14 @@ class NodeType(Enum): VARIABLE_ASSIGNER = 'variable-assigner' @classmethod - def value_of(cls, value: str) -> 'BlockType': + def value_of(cls, value: str) -> 'NodeType': """ - Get value of given block type. + Get value of given node type. 
- :param value: block type value - :return: block type + :param value: node type value + :return: node type """ - for block_type in cls: - if block_type.value == value: - return block_type - raise ValueError(f'invalid block type value {value}') + for node_type in cls: + if node_type.value == value: + return node_type + raise ValueError(f'invalid node type value {value}') From df809ff435c155510121c2e083a477b9fc13e28e Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:21:24 +0800 Subject: [PATCH 062/160] add get default node config --- api/controllers/console/app/app.py | 2 +- api/controllers/console/app/workflow.py | 35 ++++++++- .../advanced_chat/generate_task_pipeline.py | 2 +- .../{NodeEntities.py => node_entities.py} | 0 api/core/workflow/nodes/base_node.py | 12 ++++ api/core/workflow/nodes/code/__init__.py | 0 api/core/workflow/nodes/code/code_node.py | 64 +++++++++++++++++ .../workflow/nodes/direct_answer/__init__.py | 0 .../nodes/direct_answer/direct_answer_node.py | 5 ++ api/core/workflow/nodes/end/end_node.py | 5 ++ .../workflow/nodes/http_request/__init__.py | 0 .../nodes/http_request/http_request_node.py | 5 ++ api/core/workflow/nodes/if_else/__init__.py | 0 .../workflow/nodes/if_else/if_else_node.py | 5 ++ .../nodes/knowledge_retrieval/__init__.py | 0 .../knowledge_retrieval_node.py | 5 ++ api/core/workflow/nodes/llm/__init__.py | 0 api/core/workflow/nodes/llm/llm_node.py | 40 +++++++++++ .../nodes/question_classifier/__init__.py | 0 .../question_classifier_node.py | 19 +++++ api/core/workflow/nodes/start/__init__.py | 0 api/core/workflow/nodes/start/start_node.py | 5 ++ .../nodes/template_transform/__init__.py | 0 .../template_transform_node.py | 25 +++++++ api/core/workflow/nodes/tool/__init__.py | 0 api/core/workflow/nodes/tool/tool_node.py | 5 ++ .../nodes/variable_assigner/__init__.py | 0 .../variable_assigner_node.py | 5 ++ api/core/workflow/workflow_engine_manager.py | 60 ++++++++++++++++ api/services/app_service.py | 2 +- api/services/workflow/defaults.py | 72 ------------------- api/services/workflow/workflow_converter.py | 2 +- api/services/workflow_service.py | 19 ++++- 33 files changed, 314 insertions(+), 80 deletions(-) rename api/core/workflow/entities/{NodeEntities.py => node_entities.py} (100%) create mode 100644 api/core/workflow/nodes/base_node.py create mode 100644 api/core/workflow/nodes/code/__init__.py create mode 100644 api/core/workflow/nodes/code/code_node.py create mode 100644 api/core/workflow/nodes/direct_answer/__init__.py create mode 100644 api/core/workflow/nodes/direct_answer/direct_answer_node.py create mode 100644 api/core/workflow/nodes/http_request/__init__.py create mode 100644 api/core/workflow/nodes/http_request/http_request_node.py create mode 100644 api/core/workflow/nodes/if_else/__init__.py create mode 100644 api/core/workflow/nodes/if_else/if_else_node.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/__init__.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py create mode 100644 api/core/workflow/nodes/llm/__init__.py create mode 100644 api/core/workflow/nodes/llm/llm_node.py create mode 100644 api/core/workflow/nodes/question_classifier/__init__.py create mode 100644 api/core/workflow/nodes/question_classifier/question_classifier_node.py create mode 100644 api/core/workflow/nodes/start/__init__.py create mode 100644 api/core/workflow/nodes/start/start_node.py create mode 100644 api/core/workflow/nodes/template_transform/__init__.py create mode 100644 
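The commit below replaces the hard-coded list in services/workflow/defaults.py with per-class defaults: every node inherits BaseNode.get_default_config (an empty dict) and overrides it only when it has a meaningful default. A trimmed sketch of that dispatch, assuming nothing beyond the names the patch introduces (the shipped manager additionally wraps each result with its NodeType value):

from typing import Optional


class BaseNode:
    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        return {}


class TemplateTransformNode(BaseNode):
    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        return {'type': 'template-transform', 'config': {'template': '{{ arg1 }}'}}


def get_default_configs(node_classes: list[type[BaseNode]]) -> list[dict]:
    # nodes that keep the empty base default (StartNode, EndNode, ...) drop out here
    return [config for cls in node_classes if (config := cls.get_default_config())]


assert get_default_configs([BaseNode, TemplateTransformNode]) == [
    {'type': 'template-transform', 'config': {'template': '{{ arg1 }}'}}
]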
api/core/workflow/nodes/template_transform/template_transform_node.py
 create mode 100644 api/core/workflow/nodes/tool/__init__.py
 create mode 100644 api/core/workflow/nodes/tool/tool_node.py
 create mode 100644 api/core/workflow/nodes/variable_assigner/__init__.py
 create mode 100644 api/core/workflow/nodes/variable_assigner/variable_assigner_node.py
 delete mode 100644 api/services/workflow/defaults.py

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 7b2411b96fd4a9..66bcbccefe4f00 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -34,7 +34,7 @@ def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
         parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent', 'channel', 'all'], default='all', location='args', required=False)
+        parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
         parser.add_argument('name', type=str, location='args', required=False)
         args = parser.parse_args()
 
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 54585d8519a9a3..5dfb2b14436fcb 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,3 +1,5 @@
+import json
+
 from flask_restful import Resource, marshal_with, reqparse
 
 from controllers.console import api
@@ -147,7 +149,7 @@ def post(self, app_model: App):
         }
 
 
-class DefaultBlockConfigApi(Resource):
+class DefaultBlockConfigsApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -161,6 +163,34 @@ def get(self, app_model: App):
         return workflow_service.get_default_block_configs()
 
 
+class DefaultBlockConfigApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App, block_type: str):
+        """
+        Get default block config
+        """
+        parser = reqparse.RequestParser()
+        parser.add_argument('q', type=str, location='args')
+        args = parser.parse_args()
+
+        filters = None
+        if args.get('q'):
+            try:
+                filters = json.loads(args.get('q'))
+            except json.JSONDecodeError:
+                raise ValueError('Invalid filters')
+
+        # Get default block config for the given node type
+        workflow_service = WorkflowService()
+        return workflow_service.get_default_block_config(
+            node_type=block_type,
+            filters=filters
+        )
+
+
 class ConvertToWorkflowApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -188,5 +218,6 @@ def post(self, app_model: App):
 api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop')
 api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
 api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published')
-api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
+api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
+api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/:block_type')
 api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')

diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 2aa649afeab0ff..77e779a0ad89f8 100644
---
a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -30,7 +30,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.NodeEntities import NodeType +from core.workflow.entities.node_entities import NodeType from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageFile diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/node_entities.py similarity index 100% rename from api/core/workflow/entities/NodeEntities.py rename to api/core/workflow/entities/node_entities.py diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py new file mode 100644 index 00000000000000..665338af08b710 --- /dev/null +++ b/api/core/workflow/nodes/base_node.py @@ -0,0 +1,12 @@ +from typing import Optional + + +class BaseNode: + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return {} diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py new file mode 100644 index 00000000000000..7e69f91d118d7d --- /dev/null +++ b/api/core/workflow/nodes/code/code_node.py @@ -0,0 +1,64 @@ +from typing import Optional + +from core.workflow.nodes.base_node import BaseNode + + +class CodeNode(BaseNode): + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. 
+ :return: + """ + if filters and filters.get("code_language") == "javascript": + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "javascript", + "code": "async function main(arg1, arg2) {\n return new Promise((resolve, reject) => {" + "\n if (true) {\n resolve({\n \"result\": arg1 + arg2" + "\n });\n } else {\n reject(\"e\");\n }\n });\n}", + "outputs": [ + { + "variable": "result", + "variable_type": "number" + } + ] + } + } + + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "python3", + "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " + "+ arg2\n }", + "outputs": [ + { + "variable": "result", + "variable_type": "number" + } + ] + } + } diff --git a/api/core/workflow/nodes/direct_answer/__init__.py b/api/core/workflow/nodes/direct_answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py new file mode 100644 index 00000000000000..c6013974b810da --- /dev/null +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class DirectAnswerNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index e69de29bb2d1d6..f9aea89af7cb3a 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class EndNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py new file mode 100644 index 00000000000000..5be25a9834d7af --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class HttpRequestNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py new file mode 100644 index 00000000000000..98a5c85db2df32 --- /dev/null +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class IfElseNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py new file mode 100644 index 00000000000000..c6dd6249216faa --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class KnowledgeRetrievalNode(BaseNode): + pass diff --git 
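Of the node classes added in this commit, CodeNode is the only one that honours the filters argument, branching on code_language; the other overrides ignore it. A usage sketch, assuming the module path the patch creates:

from core.workflow.nodes.code.code_node import CodeNode

js_config = CodeNode.get_default_config(filters={'code_language': 'javascript'})
py_config = CodeNode.get_default_config()  # no filters -> python3 template

assert js_config['config']['code_language'] == 'javascript'
assert py_config['config']['code_language'] == 'python3'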
a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
new file mode 100644
index 00000000000000..1c7277e942e99f
--- /dev/null
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -0,0 +1,40 @@
+from typing import Optional
+
+from core.workflow.nodes.base_node import BaseNode
+
+
+class LLMNode(BaseNode):
+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        """
+        Get default config of node.
+        :param filters: filter by node config parameters.
+        :return:
+        """
+        return {
+            "type": "llm",
+            "config": {
+                "prompt_templates": {
+                    "chat_model": {
+                        "prompts": [
+                            {
+                                "role": "system",
+                                "text": "You are a helpful AI assistant."
+                            }
+                        ]
+                    },
+                    "completion_model": {
+                        "conversation_histories_role": {
+                            "user_prefix": "Human",
+                            "assistant_prefix": "Assistant"
+                        },
+                        "prompt": {
+                            "text": "Here is the chat histories between human and assistant, inside "
+                                    "<histories></histories> XML tags.\n\n<histories>\n{{"
+                                    "#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
+                        },
+                        "stop": ["Human:"]
+                    }
+                }
+            }
+        }
diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
new file mode 100644
index 00000000000000..f676b6372ac3ec
--- /dev/null
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -0,0 +1,19 @@
+from typing import Optional
+
+from core.workflow.nodes.base_node import BaseNode
+
+
+class QuestionClassifierNode(BaseNode):
+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        """
+        Get default config of node.
+        :param filters: filter by node config parameters.
+        :return:
+        """
+        return {
+            "type": "question-classifier",
+            "config": {
+                "instructions": ""  # TODO
+            }
+        }
diff --git a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
new file mode 100644
index 00000000000000..8cce655728cb2c
--- /dev/null
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -0,0 +1,5 @@
+from core.workflow.nodes.base_node import BaseNode
+
+
+class StartNode(BaseNode):
+    pass
diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
new file mode 100644
index 00000000000000..2bf26e307eaaad
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -0,0 +1,25 @@
+from typing import Optional
+
+from core.workflow.nodes.base_node import BaseNode
+
+
+class TemplateTransformNode(BaseNode):
+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        """
+        Get default config of node.
+        :param filters: filter by node config parameters.
+ :return: + """ + return { + "type": "template-transform", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + } + ], + "template": "{{ arg1 }}" + } + } diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/core/workflow/nodes/tool/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py new file mode 100644 index 00000000000000..b805a53d2f3719 --- /dev/null +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class ToolNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py new file mode 100644 index 00000000000000..231a26a6613baf --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class VariableAssignerNode(BaseNode): + pass diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index f7955a87e8cf1b..73e92d5e8948a1 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,37 @@ from typing import Optional +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db from models.model import App from models.workflow import Workflow +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.DIRECT_ANSWER: DirectAnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, +} + class WorkflowEngineManager: def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: @@ -36,3 +64,35 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: # return published workflow return workflow + + def get_default_configs(self) -> list[dict]: + """ + Get default block configs + """ + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = 
node_class.get_default_config() + if default_config: + default_block_configs.append({ + 'type': node_type.value, + 'config': default_config + }) + + return default_block_configs + + def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. + :return: + """ + node_class = node_classes.get(node_type) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config diff --git a/api/services/app_service.py b/api/services/app_service.py index f1d0e3df19393b..6011b6a667935d 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -35,7 +35,7 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination: filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) elif args['mode'] == 'chat': filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent': + elif args['mode'] == 'agent-chat': filters.append(App.mode == AppMode.AGENT_CHAT.value) elif args['mode'] == 'channel': filters.append(App.mode == AppMode.CHANNEL.value) diff --git a/api/services/workflow/defaults.py b/api/services/workflow/defaults.py deleted file mode 100644 index 67804fa4ebfc32..00000000000000 --- a/api/services/workflow/defaults.py +++ /dev/null @@ -1,72 +0,0 @@ -# default block config -default_block_configs = [ - { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant." - } - ] - }, - "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "prompt": { - "text": "Here is the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant:" - }, - "stop": ["Human:"] - } - } - } - }, - { - "type": "code", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], - "code_language": "python3", - "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " - "+ arg2\n }", - "outputs": [ - { - "variable": "result", - "variable_type": "number" - } - ] - } - }, - { - "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } - }, - { - "type": "question-classifier", - "config": { - "instructions": "" # TODO - } - } -] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 527c6543812a0a..4c7e4db47af5f5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,7 +18,7 @@ from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform -from core.workflow.entities.NodeEntities import NodeType +from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from events.app_event import app_was_created from extensions.ext_database import db diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 13ea67d343b6b3..396845d16a426d 
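End to end, the new per-block endpoint decodes an optional JSON-encoded q query parameter into the filters dict shown above. A hypothetical client call — host, app id and token are placeholders, and note that the block_type path parameter only registers correctly after the route fix in PATCH 066 below:

import json

import requests

resp = requests.get(
    'https://console.example.com/console/api/apps/APP_ID/workflows'
    '/default-workflow-block-configs/code',
    params={'q': json.dumps({'code_language': 'javascript'})},  # arrives as the filters dict
    headers={'Authorization': 'Bearer CONSOLE_ACCESS_TOKEN'},
)
print(resp.json())  # the javascript variant from CodeNode.get_default_config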
100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -4,6 +4,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account @@ -121,12 +122,26 @@ def publish_workflow(self, app_model: App, # return new workflow return workflow - def get_default_block_configs(self) -> dict: + def get_default_block_configs(self) -> list[dict]: """ Get default block configs """ # return default block config - return default_block_configs + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_configs() + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. + :return: + """ + node_type = NodeType.value_of(node_type) + + # return default block config + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_config(node_type, filters) def convert_to_workflow(self, app_model: App, account: Account) -> App: """ From de40422205ea941e562d29a501ce8782c999cffa Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:21:30 +0800 Subject: [PATCH 063/160] lint fix --- api/services/workflow_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 396845d16a426d..0be0783ae0d52f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,7 +10,6 @@ from models.account import Account from models.model import App, AppMode from models.workflow import Workflow, WorkflowType -from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter From 242fcf0145683481d6a8ebce1258fe796472744c Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:32:59 +0800 Subject: [PATCH 064/160] fix typo --- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- api/core/app/apps/base_app_runner.py | 2 +- api/core/app/apps/chat/app_runner.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 8b444ef3be3f9b..ad1e6e610d32ea 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -134,7 +134,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): input=query ) - # recalc llm max tokens + # recale llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 30e5cdd6946a14..3c7e55e293ddee 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -107,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): messages_ids=message_file_ids ) - # recalc llm max tokens + # recale llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: 
Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 4e099c9ae1529a..dda240d7789d65 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -84,7 +84,7 @@ def get_pre_calculate_rest_tokens(self, app_record: App, return rest_tokens - def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, + def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 57aca9d3e613d5..bce4606f21ba26 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -189,7 +189,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( + self.recalc_llm_max_tokens( model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index c5b8ca6c9a1e2e..d67d485e1d7d89 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -149,7 +149,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( + self.recalc_llm_max_tokens( model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) From 3086893ee76e56fdd2155b4270139805d0388c77 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 14:15:17 +0800 Subject: [PATCH 065/160] fix typo --- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index ad1e6e610d32ea..8b444ef3be3f9b 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -134,7 +134,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): input=query ) - # recale llm max tokens + # recalc llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3c7e55e293ddee..30e5cdd6946a14 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -107,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): messages_ids=message_file_ids ) - # recale llm max tokens + # recalc llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( From df753e84a3b8239cf58f04689610ceee6ff4bccd Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:27 +0800 Subject: [PATCH 066/160] fix workflow api return --- api/controllers/console/app/workflow.py | 87 ++++++++- .../app/apps/advanced_chat/app_generator.py | 16 +- 
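The commit below turns the draft-run endpoints into text/event-stream responses: each frame is either data: <json> followed by a blank line (built by _yield_response in the pipeline above) or a bare event: ping keep-alive. A hedged client-side sketch for consuming that stream — URL and token are placeholders:

import json

import requests

url = 'https://console.example.com/console/api/apps/APP_ID/advanced-chat/workflows/draft/run'
payload = {'inputs': {}, 'query': 'hello', 'files': []}

with requests.post(url, json=payload, stream=True,
                   headers={'Authorization': 'Bearer CONSOLE_ACCESS_TOKEN'}) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue  # skip blank separators and "event: ping" keep-alives
        event = json.loads(line[len('data: '):])
        print(event['event'])  # message / node_started / node_finished / workflow_finished ...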
api/core/app/apps/advanced_chat/app_runner.py | 178 +++++++++++++----- api/core/app/entities/queue_entities.py | 1 + api/core/workflow/entities/node_entities.py | 9 + api/core/workflow/entities/variable_pool.py | 82 ++++++++ api/core/workflow/nodes/base_node.py | 37 ++++ api/core/workflow/workflow_engine_manager.py | 34 +++- api/fields/workflow_fields.py | 4 +- api/fields/workflow_run_fields.py | 20 +- api/models/workflow.py | 8 + api/services/workflow_service.py | 39 +++- 12 files changed, 432 insertions(+), 83 deletions(-) create mode 100644 api/core/workflow/entities/variable_pool.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5dfb2b14436fcb..9ee6ca9dbd2d8d 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,18 +1,28 @@ import json +import logging +from typing import Generator +from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import NotFound, InternalServerError +import services from controllers.console import api -from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.error import DraftWorkflowNotExist, ConversationCompletedError from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields +from libs.helper import uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.workflow_service import WorkflowService +logger = logging.getLogger(__name__) + + class DraftWorkflowApi(Resource): @setup_required @login_required @@ -59,23 +69,80 @@ def post(self, app_model: App): } +class AdvancedChatDraftWorkflowRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + def post(self, app_model: App): + """ + Run draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + try: + response = workflow_service.run_advanced_chat_draft_workflow( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + class DraftWorkflowRunApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @get_app_model(mode=[AppMode.WORKFLOW]) def post(self, app_model: App): """ Run draft workflow """ - # TODO + parser = reqparse.RequestParser() + 
parser.add_argument('inputs', type=dict, required=True, location='json') + args = parser.parse_args() + workflow_service = WorkflowService() - workflow_service.run_draft_workflow(app_model=app_model, account=current_user) - # TODO - return { - "result": "success" - } + try: + response = workflow_service.run_draft_workflow( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') class WorkflowTaskStopApi(Resource): @@ -214,10 +281,12 @@ def post(self, app_model: App): api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') +api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published') api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/:block_type') +api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs' + '/<string:block_type>') api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow') diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ca2f4005479915..918fd4566e3733 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -16,18 +16,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message +from models.workflow import Workflow logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): def generate(self, app_model: App, + workflow: Workflow, user: Union[Account, EndUser], - args: Any, + args: dict, invoke_from: InvokeFrom, stream: bool = True) \ -> Union[dict, Generator]: @@ -35,6 +36,7 @@ def generate(self, app_model: App, Generate App response.
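The route table registered above relies on Flask URL converters such as `<uuid:app_id>` and `<string:block_type>`, and the two draft-run endpoints answer with a `text/event-stream` body. A hypothetical client call against these routes (the `http://localhost:5001/console/api` base URL and the authentication handling are assumptions, not taken from these patches; the JSON fields mirror the `RequestParser` arguments above):

```python
import requests

app_id = "..."  # a real app UUID from the console

# base URL and auth are deployment-specific; adjust for your setup
resp = requests.post(
    f"http://localhost:5001/console/api/apps/{app_id}/advanced-chat/workflows/draft/run",
    json={"inputs": {}, "query": "hello", "files": []},
    stream=True,  # the endpoint streams Server-Sent Events
)

for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))  # one SSE line per published queue event
```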
:param app_model: App + :param workflow: Workflow :param user: account or end user :param args: request args :param invoke_from: invoke from source @@ -59,16 +61,6 @@ def generate(self, app_model: App, if args.get('conversation_id'): conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) - # get workflow - workflow_engine_manager = WorkflowEngineManager() - if invoke_from == InvokeFrom.DEBUGGER: - workflow = workflow_engine_manager.get_draft_workflow(app_model=app_model) - else: - workflow = workflow_engine_manager.get_published_workflow(app_model=app_model) - - if not workflow: - raise ValueError('Workflow not initialized') - # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 0d701ae2240d96..f853f88af4db8e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,15 +1,20 @@ import logging +import time from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, InvokeFrom, ) +from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationException +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, Message +from models.account import Account +from models.model import App, Conversation, Message, EndUser logger = logging.getLogger(__name__) @@ -38,66 +43,151 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, if not app_record: raise ValueError("App not found") + workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity + ): + return + + # fetch user + if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: + user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() + else: + user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + + # RUN WORKFLOW + workflow_engine_manager = WorkflowEngineManager() + result_generator = workflow_engine_manager.run_workflow( + app_model=app_record, + workflow=workflow, + user=user, + user_inputs=inputs, + system_inputs={ + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION: conversation.id, + } + ) + + for result in result_generator: + # todo 
handle workflow and node event + pass + + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: dict, + query: str) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :param query: query + :return: + """ try: # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=app_config.tenant_id, - app_generate_entity=application_generate_entity, + tenant_id=app_generate_entity.app_config.tenant_id, + app_generate_entity=app_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: - # TODO - self.direct_output( + self._stream_output( queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION ) - return + return True - if query: - # annotation reply - annotation_reply = self.query_app_annotations_to_reply( - app_record=app_record, - message=message, - query=query, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from - ) + return False - if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER - ) - - # TODO - self.direct_output( - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - prompt_messages=prompt_messages, - text=annotation_reply.content, - stream=application_generate_entity.stream - ) - return - - # check hosting moderation - # TODO - hosting_moderation_result = self.check_hosting_moderation( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - prompt_messages=prompt_messages + def handle_annotation_reply(self, app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + """ + Handle annotation reply + :param app_record: app record + :param message: message + :param query: query + :param queue_manager: application queue manager + :param app_generate_entity: application generate entity + """ + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=app_generate_entity.user_id, + invoke_from=app_generate_entity.invoke_from ) - if hosting_moderation_result: - return + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + + self._stream_output( + queue_manager=queue_manager, + text=annotation_reply.content, + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + ) + return True + + return False - # todo RUN WORKFLOW \ No newline at end of file + def _stream_output(self, queue_manager: AppQueueManager, + text: str, + stream: bool, + stopped_by: QueueStopEvent.StopBy) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :param stream: stream + :return: + """ + if stream: + index = 0 + for token in text: + queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER) + index += 1 + 
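`_stream_output` above replays a fixed text (the moderation notice or an annotation reply) through the queue one character at a time and then closes the stream with a typed `QueueStopEvent`. Stripped of the queue plumbing, the pacing pattern reduces to the following sketch (the real code publishes each token via `queue_manager.publish_text_chunk`):

```python
import time
from collections.abc import Generator


def stream_canned_reply(text: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Yield a prepared reply one character at a time for a typing-style effect."""
    for token in text:
        yield token
        time.sleep(delay)  # same 0.01s pacing as _stream_output uses


for chunk in stream_canned_reply("Sorry, this input was flagged by moderation."):
    print(chunk, end="", flush=True)
```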
time.sleep(0.01) + + queue_manager.publish( + QueueStopEvent(stopped_by=stopped_by), + PublishFrom.APPLICATION_MANAGER + ) + queue_manager.stop_listen() diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 25bdd7d9e39c4e..e5c6a8eff943dd 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -165,6 +165,7 @@ class StopBy(Enum): USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" + INPUT_MODERATION = "input-moderation" event = QueueEvent.STOP stopped_by: StopBy diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 80471cc7024fb5..18f0f7746c1c79 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -30,3 +30,12 @@ def value_of(cls, value: str) -> 'NodeType': if node_type.value == value: return node_type raise ValueError(f'invalid node type value {value}') + + +class SystemVariable(Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION = 'conversation' diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py new file mode 100644 index 00000000000000..eefee88c07ca48 --- /dev/null +++ b/api/core/workflow/entities/variable_pool.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Optional, Union, Any + +from core.workflow.entities.node_entities import SystemVariable + +VariableValue = Union[str, int, float, dict, list] + + +class ValueType(Enum): + """ + Value Type Enum + """ + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY = "array" + FILE = "file" + + +class VariablePool: + variables_mapping = {} + + def __init__(self, system_variables: dict[SystemVariable, Any]) -> None: + # system variables + # for example: + # { + # 'query': 'abc', + # 'files': [] + # } + for system_variable, value in system_variables.items(): + self.append_variable('sys', [system_variable.value], value) + + def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: + """ + Append variable + :param node_id: node id + :param variable_key_list: variable key list, like: ['result', 'text'] + :param value: value + :return: + """ + if node_id not in self.variables_mapping: + self.variables_mapping[node_id] = {} + + variable_key_list_hash = hash(tuple(variable_key_list)) + + self.variables_mapping[node_id][variable_key_list_hash] = value + + def get_variable_value(self, variable_selector: list[str], + target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]: + """ + Get variable + :param variable_selector: include node_id and variables + :param target_value_type: target value type + :return: + """ + if len(variable_selector) < 2: + raise ValueError('Invalid value selector') + + node_id = variable_selector[0] + if node_id not in self.variables_mapping: + return None + + # fetch variable keys, pop node_id + variable_key_list = variable_selector[1:] + + variable_key_list_hash = hash(tuple(variable_key_list)) + + value = self.variables_mapping[node_id].get(variable_key_list_hash) + + if target_value_type: + if target_value_type == ValueType.STRING: + return str(value) + elif target_value_type == ValueType.NUMBER: + return int(value) + elif target_value_type == ValueType.OBJECT: + if not isinstance(value, dict): + raise ValueError('Invalid value type: object') + elif target_value_type == ValueType.ARRAY: + if not 
isinstance(value, list): + raise ValueError('Invalid value type: array') + + return value diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 665338af08b710..a2751b346fd6c5 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,7 +1,44 @@ +from abc import abstractmethod from typing import Optional +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool + class BaseNode: + _node_type: NodeType + + def __int__(self, node_config: dict) -> None: + self._node_config = node_config + + @abstractmethod + def run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + if variable_pool is None and run_args is None: + raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") + + return self._run( + variable_pool=variable_pool, + run_args=run_args + ) + + @abstractmethod + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + raise NotImplementedError + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 73e92d5e8948a1..5914bfc1522243 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Optional, Union, Generator +from core.memory.token_buffer_memory import TokenBufferMemory from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -14,7 +15,8 @@ from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db -from models.model import App +from models.account import Account +from models.model import App, EndUser, Conversation from models.workflow import Workflow node_classes = { @@ -56,13 +58,20 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: return None # fetch published workflow by workflow_id + return self.get_workflow(app_model, app_model.workflow_id) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id workflow = db.session.query(Workflow).filter( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id + Workflow.id == workflow_id ).first() - # return published workflow + # return workflow return workflow def get_default_configs(self) -> list[dict]: @@ -96,3 +105,20 @@ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None return None return default_config + + def run_workflow(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> Generator: + """ + Run workflow + :param app_model: App instance + :param workflow: Workflow instance + :param user: account or end user + :param user_inputs: user variables inputs + :param 
system_inputs: system inputs, like: query, files + :return: + """ + # TODO + pass diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index bcb2c318c6a8f9..9919a440e8d91e 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,8 +5,8 @@ workflow_fields = { 'id': fields.String, - 'graph': fields.Nested(simple_account_fields, attribute='graph_dict'), - 'features': fields.Nested(simple_account_fields, attribute='features_dict'), + 'graph': fields.Raw(attribute='graph_dict'), + 'features': fields.Raw(attribute='features_dict'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 37751bc70f9058..85c9c2d2b293f5 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -22,10 +22,10 @@ "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.String, - "inputs": fields.String, + "graph": fields.Raw(attribute='graph_dict'), + "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.String, + "outputs": fields.Raw(attribute='outputs_dict'), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, @@ -49,10 +49,10 @@ "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.String, - "inputs": fields.String, + "graph": fields.Raw(attribute='graph_dict'), + "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.String, + "outputs": fields.Raw(attribute='outputs_dict'), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, @@ -73,13 +73,13 @@ "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.String, - "process_data": fields.String, - "outputs": fields.String, + "inputs": fields.Raw(attribute='inputs_dict'), + "process_data": fields.Raw(attribute='process_data_dict'), + "outputs": fields.Raw(attribute='outputs_dict'), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.String, + "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), "created_at": TimestampField, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), diff --git a/api/models/workflow.py b/api/models/workflow.py index 2540d334025c4e..32ff26196c4378 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -272,6 +272,14 @@ def created_by_end_user(self): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def graph_dict(self): + return self.graph if not self.graph else json.loads(self.graph) + + @property + def inputs_dict(self): + return self.inputs if not self.inputs else json.loads(self.inputs) + @property def outputs_dict(self): return self.outputs if not self.outputs else json.loads(self.outputs) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0be0783ae0d52f..37f5c16bec3e26 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,14 +1,16 @@ import json from datetime import datetime -from typing import Optional +from typing import Optional, 
Union, Any, Generator from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode +from models.model import App, AppMode, EndUser from models.workflow import Workflow, WorkflowType from services.workflow.workflow_converter import WorkflowConverter @@ -142,6 +144,39 @@ def get_default_block_config(self, node_type: str, filters: Optional[dict] = Non workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) + def run_advanced_chat_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + """ + Run advanced chatbot draft workflow + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = AdvancedChatAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response + + def run_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + # TODO + pass + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ Basic mode of chatbot app(expert mode) to workflow From c8a1f923f53f720e84a941456284ee3f3de167c7 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:35 +0800 Subject: [PATCH 067/160] lint fix --- api/controllers/console/app/workflow.py | 7 +++---- api/core/app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/advanced_chat/app_runner.py | 5 +++-- api/core/workflow/workflow_engine_manager.py | 6 +++--- api/services/workflow_service.py | 3 ++- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9ee6ca9dbd2d8d..6e77f50e653f42 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,14 +1,14 @@ import json import logging -from typing import Generator +from collections.abc import Generator from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse -from werkzeug.exceptions import NotFound, InternalServerError +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import DraftWorkflowNotExist, ConversationCompletedError +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -19,7 +19,6 @@ from models.model import App, AppMode from services.workflow_service import WorkflowService - logger = logging.getLogger(__name__) diff --git 
a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 918fd4566e3733..937f95679a43a1 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -2,7 +2,7 @@ import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Union from flask import Flask, current_app from pydantic import ValidationError diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index f853f88af4db8e..02d22072dfd511 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -6,7 +6,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, InvokeFrom, + AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationException @@ -14,7 +15,7 @@ from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, Message, EndUser +from models.model import App, Conversation, EndUser, Message logger = logging.getLogger(__name__) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 5914bfc1522243..8a230487052c49 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,6 @@ -from typing import Optional, Union, Generator +from collections.abc import Generator +from typing import Optional, Union -from core.memory.token_buffer_memory import TokenBufferMemory from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -16,7 +16,7 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db from models.account import Account -from models.model import App, EndUser, Conversation +from models.model import App, EndUser from models.workflow import Workflow node_classes = { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 37f5c16bec3e26..2c1b6eb819cb38 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,7 @@ import json +from collections.abc import Generator from datetime import datetime -from typing import Optional, Union, Any, Generator +from typing import Optional, Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator From 1a86e79d4a6b32ed818f3278e0377dab17060aba Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:40 +0800 Subject: [PATCH 068/160] lint fix --- api/core/workflow/entities/variable_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index eefee88c07ca48..e84044dede1982 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,5 +1,5 @@ from enum import Enum -from typing 
import Optional, Union, Any +from typing import Any, Optional, Union from core.workflow.entities.node_entities import SystemVariable From 75f1355d4c742399f247a7dd0737512b6f1741db Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 23:34:23 +0800 Subject: [PATCH 069/160] add few workflow run codes --- api/commands.py | 2 +- api/core/app/app_config/entities.py | 1 + api/core/app/apps/advanced_chat/app_runner.py | 7 +- api/core/callback_handler/__init__.py | 0 .../std_out_callback_handler.py | 157 ------------------ .../workflow_event_trigger_callback.py | 45 +++++ api/core/workflow/callbacks/__init__.py | 0 api/core/workflow/callbacks/base_callback.py | 33 ++++ .../entities/base_node_data_entities.py | 7 + api/core/workflow/nodes/base_node.py | 35 ++-- api/core/workflow/nodes/start/entities.py | 27 +++ api/core/workflow/nodes/start/start_node.py | 19 ++- api/core/workflow/workflow_engine_manager.py | 94 ++++++++++- 13 files changed, 249 insertions(+), 178 deletions(-) create mode 100644 api/core/callback_handler/__init__.py delete mode 100644 api/core/callback_handler/std_out_callback_handler.py create mode 100644 api/core/callback_handler/workflow_event_trigger_callback.py create mode 100644 api/core/workflow/callbacks/__init__.py create mode 100644 api/core/workflow/callbacks/base_callback.py create mode 100644 api/core/workflow/entities/base_node_data_entities.py create mode 100644 api/core/workflow/nodes/start/entities.py diff --git a/api/commands.py b/api/commands.py index 73325620ee8b69..376a394d1e18e6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation +from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index e155dc1c4dd388..6a521dfcc5b7b5 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -112,6 +112,7 @@ def value_of(cls, value: str) -> 'VariableEntity.Type': max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None + hint: Optional[str] = None class ExternalDataVariableEntity(BaseModel): diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 02d22072dfd511..920adcfb79cc9a 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -10,12 +10,14 @@ InvokeFrom, ) from core.app.entities.queue_entities import QueueStopEvent +from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message +from models.workflow import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -83,13 +85,16 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, result_generator = workflow_engine_manager.run_workflow( app_model=app_record, 
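The `run_workflow` call being assembled above feeds user and system inputs into the `VariablePool` added in PATCH 066: every system value is filed under the reserved `sys` node id, and values are resolved with `[node_id, *variable_keys]` selectors. A usage sketch against the classes added in this series:

```python
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool(system_variables={
    SystemVariable.QUERY: "What is Dify?",
    SystemVariable.FILES: [],
})

# system values are resolved with a ['sys', <key>] selector
assert pool.get_variable_value(["sys", "query"]) == "What is Dify?"

# a node appends its outputs the same way; downstream nodes read them back
pool.append_variable("node_1", ["result", "text"], "hello")
assert pool.get_variable_value(["node_1", "result", "text"]) == "hello"
```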
workflow=workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, user=user, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, - } + }, + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) for result in result_generator: diff --git a/api/core/callback_handler/__init__.py b/api/core/callback_handler/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py deleted file mode 100644 index 1f95471afb2c73..00000000000000 --- a/api/core/callback_handler/std_out_callback_handler.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import sys -from typing import Any, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - - -class DifyStdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - print_text("\n[on_chat_model_start]\n", color='blue') - for sub_messages in messages: - for sub_message in sub_messages: - print_text(str(sub_message) + "\n", color='blue') - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - print_text("\n[on_llm_start]\n", color='blue') - print_text(prompts[0] + "\n", color='blue') - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str( - response.llm_output) + "\n", color='blue') - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') - - def on_chain_start( - self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - chain_type = serialized['id'][-1] - print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - - def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink') - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_start] " + str(serialized), color='yellow') - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = action.tool_input - try: - action_name_position = 
action.log.index("\nAction:") + 1 if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - except ValueError: - thought = '' - - log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" - print_text("\n[on_agent_action]\n" + log + "\n", color='green') - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - print_text("\n[on_tool_end]\n", color='yellow') - if observation_prefix: - print_text(f"\n{observation_prefix}") - print_text(output, color='yellow') - if llm_prefix: - print_text(f"\n{llm_prefix}") - print_text("\n") - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow') - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - -class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): - """Callback handler for streaming. Only works with LLMs that support streaming.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py new file mode 100644 index 00000000000000..2f81f27426e23d --- /dev/null +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -0,0 +1,45 @@ +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager): + self._queue_manager = queue_manager + + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + self._queue_manager.publish_workflow_started( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + self._queue_manager.publish_workflow_finished( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish_node_started( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + self._queue_manager.publish_node_finished( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_callback.py new file mode 100644 index 00000000000000..a564af498cc6a1 --- /dev/null +++ b/api/core/workflow/callbacks/base_callback.py @@ -0,0 +1,33 @@ +from abc import abstractmethod + +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class BaseWorkflowCallback: + @abstractmethod + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py new file mode 100644 index 00000000000000..32b93ea094c4fe --- /dev/null +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -0,0 +1,7 @@ +from abc import ABC + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + pass diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a2751b346fd6c5..a95a232ae649b6 100644 --- 
a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,21 +1,37 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Type +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool class BaseNode: _node_type: NodeType + _node_data_cls: Type[BaseNodeData] - def __int__(self, node_config: dict) -> None: - self._node_config = node_config + def __init__(self, config: dict) -> None: + self._node_id = config.get("id") + if not self._node_id: + raise ValueError("Node ID is required.") + + self._node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + raise NotImplementedError + def run(self, variable_pool: Optional[VariablePool] = None, run_args: Optional[dict] = None) -> dict: """ - Run node + Run node entry :param variable_pool: variable pool :param run_args: run args :return: @@ -28,17 +44,6 @@ def run(self, variable_pool: Optional[VariablePool] = None, run_args=run_args ) - @abstractmethod - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: - """ - Run node - :param variable_pool: variable pool - :param run_args: run args - :return: - """ - raise NotImplementedError - @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py new file mode 100644 index 00000000000000..25b27cf1922886 --- /dev/null +++ b/api/core/workflow/nodes/start/entities.py @@ -0,0 +1,27 @@ +from typing import Optional + +from core.app.app_config.entities import VariableEntity +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType + + +class StartNodeData(BaseNodeData): + """ + - title (string) node title + - desc (string) optional, node description + - type (string) node type, fixed to start + - variables (array[object]) list of form variables + - type (string) form variable type: text-input, paragraph, select, number, files (custom files are not supported yet) + - label (string) display label of the input control + - variable (string) variable key + - max_length (int) maximum length, applies to text-input and paragraph + - default (string) optional, default value + - required (bool) optional, whether the variable is required, defaults to false + - hint (string) optional, hint text + - options (array[string]) option values (select only) + """ + type: str = NodeType.START.value + + title: str + desc: Optional[str] = None + variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 8cce655728cb2c..014a146c93a550 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,5 +1,22 @@ +from typing import Type, Optional + +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.start.entities import StartNodeData class StartNode(BaseNode): - pass + _node_type = NodeType.START + _node_data_cls = StartNodeData + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + 
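Under the reworked contract above, a concrete node only declares its typed `_node_data_cls` and implements `_run`; `BaseNode.__init__` parses the graph config and `run` validates the inputs. A minimal illustrative subclass under that contract (`EchoNode` is hypothetical and not part of the `node_classes` map in this series):

```python
from typing import Optional

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode


class EchoNodeData(BaseNodeData):
    prefix: str = ""


class EchoNode(BaseNode):
    _node_data_cls = EchoNodeData  # hypothetical node, for illustration only

    def _run(self, variable_pool: Optional[VariablePool] = None,
             run_args: Optional[dict] = None) -> dict:
        # single-node debug runs pass run_args; full runs resolve from the pool
        if variable_pool is not None:
            query = variable_pool.get_variable_value(["sys", "query"]) or ""
        else:
            query = (run_args or {}).get("query", "")
        return {"text": f"{self._node_data.prefix}{query}"}


node = EchoNode(config={"id": "echo-1", "data": {"prefix": ">> "}})
print(node.run(run_args={"query": "ping"}))  # {'text': '>> ping'}
```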
:param run_args: run args + :return: + """ + pass + diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8a230487052c49..afa4dbb321fac8 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,8 @@ +import json from collections.abc import Generator from typing import Optional, Union +from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -17,7 +19,7 @@ from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole node_classes = { NodeType.START: StartNode, @@ -108,17 +110,103 @@ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None def run_workflow(self, app_model: App, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], user_inputs: dict, - system_inputs: Optional[dict] = None) -> Generator: + system_inputs: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> Generator: """ Run workflow :param app_model: App instance :param workflow: Workflow instance + :param triggered_from: triggered from :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks :return: """ - # TODO + # fetch workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + # init workflow run + workflow_run = self._init_workflow_run( + workflow=workflow, + triggered_from=triggered_from, + user=user, + user_inputs=user_inputs, + system_inputs=system_inputs + ) + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started(workflow_run) + pass + + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + try: + db.session.begin() + + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .for_update() \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **system_inputs}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by_id=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + except: + db.session.rollback() + raise + + return workflow_run + + def 
_get_entry_node(self, graph: dict) -> Optional[StartNode]: + """ + Get entry node + :param graph: workflow graph + :return: + """ + nodes = graph.get('nodes') + if not nodes: + return None + + for node_config in nodes.items(): + if node_config.get('type') == NodeType.START.value: + return StartNode(config=node_config) + + return None From bc4edbfc2bb5526a062248589c9c1f3aee623fe1 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 23:34:28 +0800 Subject: [PATCH 070/160] lint fix --- api/core/callback_handler/workflow_event_trigger_callback.py | 2 +- api/core/workflow/callbacks/base_callback.py | 2 +- api/core/workflow/nodes/base_node.py | 4 ++-- api/core/workflow/nodes/start/start_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index 2f81f27426e23d..e1d2413534ebac 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,6 +1,6 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.workflow.callbacks.base_callback import BaseWorkflowCallback -from models.workflow import WorkflowRun, WorkflowNodeExecution +from models.workflow import WorkflowNodeExecution, WorkflowRun class WorkflowEventTriggerCallback(BaseWorkflowCallback): diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_callback.py index a564af498cc6a1..76fe4d96d53d94 100644 --- a/api/core/workflow/callbacks/base_callback.py +++ b/api/core/workflow/callbacks/base_callback.py @@ -1,6 +1,6 @@ from abc import abstractmethod -from models.workflow import WorkflowRun, WorkflowNodeExecution +from models.workflow import WorkflowNodeExecution, WorkflowRun class BaseWorkflowCallback: diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a95a232ae649b6..6f28a3f1040a42 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional, Type +from typing import Optional from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -8,7 +8,7 @@ class BaseNode: _node_type: NodeType - _node_data_cls: Type[BaseNodeData] + _node_data_cls: type[BaseNodeData] def __init__(self, config: dict) -> None: self._node_id = config.get("id") diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 014a146c93a550..e218cced3d1230 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,4 @@ -from typing import Type, Optional +from typing import Optional from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index afa4dbb321fac8..3ad36fe1d26d82 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -19,7 +19,7 @@ from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole +from models.workflow import 
CreatedByRole, Workflow, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom node_classes = { NodeType.START: StartNode, From a5de7b10f36d4854c70630cf19c956854c1eefef Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 5 Mar 2024 17:35:05 +0800 Subject: [PATCH 071/160] update ruff check --- web/.husky/pre-commit | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index dfd6ec02095e35..1f8ae9a8d39cf1 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -24,7 +24,21 @@ done if $api_modified; then echo "Running Ruff linter on api module" - ./dev/reformat + + # python style checks rely on `ruff` in path + if ! command -v ruff &> /dev/null; then + echo "Installing Ruff ..." + pip install ruff + fi + + ruff check ./api + result=$? + + if [ $result -ne 0 ]; then + echo "Please run 'dev/reformat' to fix the fixable linting errors." + fi + + exit $result fi if $web_modified; then From 79a10e97295e2ae92ee819904b5f97b3f7b1092b Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 13:26:14 +0800 Subject: [PATCH 072/160] add updated_at to sync workflow api --- api/controllers/console/app/workflow.py | 7 +- api/core/app/apps/advanced_chat/app_runner.py | 7 +- .../entities/base_node_data_entities.py | 6 +- .../workflow/entities/workflow_entities.py | 16 ++ api/core/workflow/nodes/base_node.py | 24 ++- api/core/workflow/nodes/start/entities.py | 4 - api/core/workflow/nodes/start/start_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 184 +++++++++++++++++- api/libs/helper.py | 2 +- web/.husky/pre-commit | 12 +- 10 files changed, 233 insertions(+), 31 deletions(-) create mode 100644 api/core/workflow/entities/workflow_entities.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6e77f50e653f42..4f8df6bcec5a34 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -14,7 +14,7 @@ from controllers.console.wraps import account_initialization_required from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields -from libs.helper import uuid_value +from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.workflow_service import WorkflowService @@ -56,7 +56,7 @@ def post(self, app_model: App): args = parser.parse_args() workflow_service = WorkflowService() - workflow_service.sync_draft_workflow( + workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args.get('graph'), features=args.get('features'), @@ -64,7 +64,8 @@ def post(self, app_model: App): ) return { - "result": "success" + "result": "success", + "updated_at": TimestampField().format(workflow.updated_at) } diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 920adcfb79cc9a..898091f52c2b1b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -82,7 +82,7 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() - result_generator = workflow_engine_manager.run_workflow( + workflow_engine_manager.run_workflow( app_model=app_record, workflow=workflow, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING @@ -97,11 +97,6 @@ def run(self, application_generate_entity: 
AdvancedChatAppGenerateEntity, callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) - for result in result_generator: - # todo handle workflow and node event - pass - - def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index 32b93ea094c4fe..afa6ddff047a2e 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -1,7 +1,11 @@ from abc import ABC +from typing import Optional from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): - pass + type: str + + title: str + desc: Optional[str] = None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py new file mode 100644 index 00000000000000..21126caf30f523 --- /dev/null +++ b/api/core/workflow/entities/workflow_entities.py @@ -0,0 +1,16 @@ +from decimal import Decimal + +from core.workflow.entities.variable_pool import VariablePool +from models.workflow import WorkflowNodeExecution, WorkflowRun + + +class WorkflowRunState: + workflow_run: WorkflowRun + start_at: float + variable_pool: VariablePool + + total_tokens: int = 0 + total_price: Decimal = Decimal(0) + currency: str = "USD" + + workflow_node_executions: list[WorkflowNodeExecution] = [] diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6f28a3f1040a42..314dfb8f224696 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,21 +1,25 @@ from abc import abstractmethod from typing import Optional +from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool class BaseNode: - _node_type: NodeType _node_data_cls: type[BaseNodeData] + _node_type: NodeType + + node_id: str + node_data: BaseNodeData def __init__(self, config: dict) -> None: - self._node_id = config.get("id") - if not self._node_id: + self.node_id = config.get("id") + if not self.node_id: raise ValueError("Node ID is required.") - self._node_data = self._node_data_cls(**config.get("data", {})) + self.node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod def _run(self, variable_pool: Optional[VariablePool] = None, @@ -29,11 +33,13 @@ def _run(self, variable_pool: Optional[VariablePool] = None, raise NotImplementedError def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> dict: """ Run node entry :param variable_pool: variable pool :param run_args: run args + :param callbacks: callbacks :return: """ if variable_pool is None and run_args is None: @@ -52,3 +58,11 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: :return: """ return {} + + @property + def node_type(self) -> NodeType: + """ + Get node type + :return: + """ + return self._node_type diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 25b27cf1922886..64687db042f1dc 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,5 +1,3 @@ -from typing import 
Optional - from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -22,6 +20,4 @@ class StartNodeData(BaseNodeData): """ type: str = NodeType.START.value - title: str - desc: Optional[str] = None variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index e218cced3d1230..74d854143603e2 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -7,8 +7,8 @@ class StartNode(BaseNode): - _node_type = NodeType.START _node_data_cls = StartNodeData + node_type = NodeType.START def _run(self, variable_pool: Optional[VariablePool] = None, run_args: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3ad36fe1d26d82..0ec93dd4b21018 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,12 @@ import json -from collections.abc import Generator +import time from typing import Optional, Union from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_entities import WorkflowRunState +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode @@ -19,7 +22,16 @@ from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import CreatedByRole, Workflow, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) node_classes = { NodeType.START: StartNode, @@ -114,7 +126,7 @@ def run_workflow(self, app_model: App, user: Union[Account, EndUser], user_inputs: dict, system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> Generator: + callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param app_model: App instance @@ -140,11 +152,66 @@ def run_workflow(self, app_model: App, system_inputs=system_inputs ) + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow_run=workflow_run, + start_at=time.perf_counter(), + variable_pool=VariablePool( + system_variables=system_inputs, + ) + ) + if callbacks: for callback in callbacks: callback.on_workflow_run_started(workflow_run) - pass + # fetch start node + start_node = self._get_entry_node(graph) + if not start_node: + self._workflow_run_failed( + workflow_run_state=workflow_run_state, + error='Start node not found in workflow graph', + callbacks=callbacks + ) + return + + try: + predecessor_node = None + current_node = start_node + while True: + # run workflow + self._run_workflow_node( + workflow_run_state=workflow_run_state, + node=current_node, + predecessor_node=predecessor_node, + callbacks=callbacks + ) + + if current_node.node_type == NodeType.END: + break + + # todo fetch next node until end node finished or no next node + 
current_node = None + + if not current_node: + break + + predecessor_node = current_node + # or max steps 30 reached + # or max execution time 10min reached + except Exception as e: + self._workflow_run_failed( + workflow_run_state=workflow_run_state, + error=str(e), + callbacks=callbacks + ) + return + + # workflow run success + self._workflow_run_success( + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, @@ -184,7 +251,7 @@ def _init_workflow_run(self, workflow: Workflow, status=WorkflowRunStatus.RUNNING.value, created_by_role=(CreatedByRole.ACCOUNT.value if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by_id=user.id + created_by=user.id ) db.session.add(workflow_run) @@ -195,6 +262,33 @@ def _init_workflow_run(self, workflow: Workflow, return workflow_run + def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, + error: str, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run_state: workflow run state + :param error: error message + :param callbacks: workflow callbacks + :return: + """ + workflow_run = workflow_run_state.workflow_run + workflow_run.status = WorkflowRunStatus.FAILED.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at + workflow_run.total_tokens = workflow_run_state.total_tokens + workflow_run.total_price = workflow_run_state.total_price + workflow_run.currency = workflow_run_state.currency + workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) + + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_finished(workflow_run) + + return workflow_run + def _get_entry_node(self, graph: dict) -> Optional[StartNode]: """ Get entry node @@ -210,3 +304,83 @@ def _get_entry_node(self, graph: dict) -> Optional[StartNode]: return StartNode(config=node_config) return None + + def _run_workflow_node(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + # init workflow node execution + start_at = time.perf_counter() + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run_state=workflow_run_state, + node=node, + predecessor_node=predecessor_node, + ) + + # add to workflow node executions + workflow_run_state.workflow_node_executions.append(workflow_node_execution) + + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool, + callbacks=callbacks + ) + except Exception as e: + # node run failed + self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + error=str(e), + callbacks=callbacks + ) + raise + + # node run success + self._workflow_node_execution_success( + workflow_node_execution=workflow_node_execution, + result=node_run_result, + callbacks=callbacks + ) + + return workflow_node_execution + + def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run_state: workflow run state + :param node: current node + 
:param predecessor_node: predecessor node if exists
+        :param callbacks: workflow callbacks
+        :return:
+        """
+        workflow_run = workflow_run_state.workflow_run
+
+        # init workflow node execution
+        workflow_node_execution = WorkflowNodeExecution(
+            tenant_id=workflow_run.tenant_id,
+            app_id=workflow_run.app_id,
+            workflow_id=workflow_run.workflow_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            workflow_run_id=workflow_run.id,
+            predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
+            index=len(workflow_run_state.workflow_node_executions) + 1,
+            node_id=node.node_id,
+            node_type=node.node_type.value,
+            title=node.node_data.title,
+            status=WorkflowNodeExecutionStatus.RUNNING.value,
+            created_by_role=workflow_run.created_by_role,
+            created_by=workflow_run.created_by
+        )
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_node_execute_started(workflow_node_execution)
+
+        return workflow_node_execution
diff --git a/api/libs/helper.py b/api/libs/helper.py
index a35f4ad4710868..3eb14c50f049e3 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -15,7 +15,7 @@ def run(script):
 
 
 class TimestampField(fields.Raw):
-    def format(self, value):
+    def format(self, value) -> int:
         return int(value.timestamp())
 
 
diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit
index 1f8ae9a8d39cf1..4bc7fb77abe728 100755
--- a/web/.husky/pre-commit
+++ b/web/.husky/pre-commit
@@ -31,14 +31,16 @@ if $api_modified; then
         pip install ruff
     fi
 
-    ruff check ./api
-    result=$?
+    ruff check ./api || status=$?
 
-    if [ $result -ne 0 ]; then
+    status=${status:-0}
+
+
+    if [ $status -ne 0 ]; then
+        echo "Ruff linter on api module error, exit code: $status"
         echo "Please run 'dev/reformat' to fix the fixable linting errors."
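+        # exiting non-zero here is what actually aborts the commit; a pre-commit
+        # hook that only echoes would let the unlinted code land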
+ exit 1 fi - - exit $result fi if $web_modified; then From dd50deaa438dc264ebfcbaf30e9fab30824ea681 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 13:45:01 +0800 Subject: [PATCH 073/160] fix audio voice arg --- api/services/audio_service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a658487f83a2d..d013a51c3e6507 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -64,7 +64,8 @@ def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[st return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: str, streaming: bool, end_user: Optional[str] = None): + def transcript_tts(cls, app_model: App, text: str, streaming: bool, + voice: Optional[str] = None, end_user: Optional[str] = None): if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: workflow = app_model.workflow if workflow is None: @@ -74,14 +75,14 @@ def transcript_tts(cls, app_model: App, text: str, streaming: bool, end_user: Op if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') + voice = features_dict['text_to_speech'].get('voice') if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict if not text_to_speech_dict.get('enabled'): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice'), + voice = text_to_speech_dict.get('voice') if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( From 7d28fe8ea5d0b295a4d2e0073c8593fcc86f1870 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 17:43:42 +0800 Subject: [PATCH 074/160] completed workflow engine main logic --- api/core/app/apps/advanced_chat/app_runner.py | 3 +- .../advanced_chat/generate_task_pipeline.py | 2 - .../workflow_event_trigger_callback.py | 11 +- ..._callback.py => base_workflow_callback.py} | 8 + api/core/workflow/entities/node_entities.py | 21 ++ .../workflow/entities/workflow_entities.py | 9 +- api/core/workflow/nodes/base_node.py | 48 ++- api/core/workflow/workflow_engine_manager.py | 334 +++++++++++++++--- api/fields/workflow_run_fields.py | 6 - api/models/workflow.py | 4 - 10 files changed, 366 insertions(+), 80 deletions(-) rename api/core/workflow/callbacks/{base_callback.py => base_workflow_callback.py} (85%) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 898091f52c2b1b..c5ffa8016567ac 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -83,7 +83,6 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( - app_model=app_record, workflow=workflow, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, @@ -94,7 +93,7 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + 
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)], ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 77e779a0ad89f8..cfeb46f05a2b80 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -253,8 +253,6 @@ def _process_stream_response(self) -> Generator: 'error': workflow_run.error, 'elapsed_time': workflow_run.elapsed_time, 'total_tokens': workflow_run.total_tokens, - 'total_price': workflow_run.total_price, - 'currency': workflow_run.currency, 'total_steps': workflow_run.total_steps, 'created_at': int(workflow_run.created_at.timestamp()), 'finished_at': int(workflow_run.finished_at.timestamp()) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index e1d2413534ebac..80dabc75489738 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,5 +1,5 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -43,3 +43,12 @@ def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNod workflow_node_execution_id=workflow_node_execution.id, pub_from=PublishFrom.TASK_PIPELINE ) + + def on_text_chunk(self, text: str) -> None: + """ + Publish text chunk + """ + self._queue_manager.publish_text_chunk( + text=text, + pub_from=PublishFrom.TASK_PIPELINE + ) diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py similarity index 85% rename from api/core/workflow/callbacks/base_callback.py rename to api/core/workflow/callbacks/base_workflow_callback.py index 76fe4d96d53d94..3425b2b03c7111 100644 --- a/api/core/workflow/callbacks/base_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -31,3 +31,11 @@ def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNod Workflow node execute finished """ raise NotImplementedError + + @abstractmethod + def on_text_chunk(self, text: str) -> None: + """ + Publish text chunk + """ + raise NotImplementedError + diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 18f0f7746c1c79..af539692ef8eaf 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,4 +1,9 @@ from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from models.workflow import WorkflowNodeExecutionStatus class NodeType(Enum): @@ -39,3 +44,19 @@ class SystemVariable(Enum): QUERY = 'query' FILES = 'files' CONVERSATION = 'conversation' + + +class NodeRunResult(BaseModel): + """ + Node Run Result. 
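+
+    Carries everything a node reports back to the engine: the terminal
+    status, the node's inputs / process_data / outputs dicts, optional
+    metadata such as token usage, the edge_source_handle a branching node
+    chose, and an error message when the run failed. An illustrative
+    successful result (field values are made up):
+
+        NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs={"query": "hi"},
+            outputs={"text": "hello"},
+            metadata={"total_tokens": 12},
+        )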
+ """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING + + inputs: Optional[dict] = None # node inputs + process_data: Optional[dict] = None # process data + outputs: Optional[dict] = None # node outputs + metadata: Optional[dict] = None # node metadata + + edge_source_handle: Optional[str] = None # source handle id of node with multiple branches + + error: Optional[str] = None # error message if status is failed diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 21126caf30f523..0d78e4c4f1fff5 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -1,5 +1,3 @@ -from decimal import Decimal - from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -10,7 +8,10 @@ class WorkflowRunState: variable_pool: VariablePool total_tokens: int = 0 - total_price: Decimal = Decimal(0) - currency: str = "USD" workflow_node_executions: list[WorkflowNodeExecution] = [] + + def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None: + self.workflow_run = workflow_run + self.start_at = start_at + self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 314dfb8f224696..efffdfae1aeeba 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,10 +1,11 @@ from abc import abstractmethod from typing import Optional -from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from models.workflow import WorkflowNodeExecutionStatus class BaseNode: @@ -13,17 +14,23 @@ class BaseNode: node_id: str node_data: BaseNodeData + node_run_result: Optional[NodeRunResult] = None - def __init__(self, config: dict) -> None: + stream_output_supported: bool = False + callbacks: list[BaseWorkflowCallback] + + def __init__(self, config: dict, + callbacks: list[BaseWorkflowCallback] = None) -> None: self.node_id = config.get("id") if not self.node_id: raise ValueError("Node ID is required.") self.node_data = self._node_data_cls(**config.get("data", {})) + self.callbacks = callbacks or [] @abstractmethod def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node :param variable_pool: variable pool @@ -33,22 +40,41 @@ def _run(self, variable_pool: Optional[VariablePool] = None, raise NotImplementedError def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node entry :param variable_pool: variable pool :param run_args: run args - :param callbacks: callbacks :return: """ if variable_pool is None and run_args is None: raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") - return self._run( - variable_pool=variable_pool, - run_args=run_args - ) + try: + result = self._run( + variable_pool=variable_pool, + 
run_args=run_args
+            )
+        except Exception as e:
+            # process unhandled exception
+            result = NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )
+
+        self.node_run_result = result
+        return result
+
+    def publish_text_chunk(self, text: str) -> None:
+        """
+        Publish text chunk
+        :param text: chunk text
+        :return:
+        """
+        if self.stream_output_supported:
+            if self.callbacks:
+                for callback in self.callbacks:
+                    callback.on_text_chunk(text)
 
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py
index 0ec93dd4b21018..908b6849309105 100644
--- a/api/core/workflow/workflow_engine_manager.py
+++ b/api/core/workflow/workflow_engine_manager.py
@@ -1,10 +1,11 @@
 import json
 import time
+from datetime import datetime
 from typing import Optional, Union
 
-from core.workflow.callbacks.base_callback import BaseWorkflowCallback
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.entities.workflow_entities import WorkflowRunState
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.code.code_node import CodeNode
@@ -31,6 +32,7 @@
     WorkflowRun,
     WorkflowRunStatus,
     WorkflowRunTriggeredFrom,
+    WorkflowType,
 )
 
 node_classes = {
@@ -120,8 +122,7 @@ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None
 
         return default_config
 
-    def run_workflow(self, app_model: App,
-                     workflow: Workflow,
+    def run_workflow(self, workflow: Workflow,
                      triggered_from: WorkflowRunTriggeredFrom,
                      user: Union[Account, EndUser],
                      user_inputs: dict,
@@ -129,7 +130,6 @@ def run_workflow(self, app_model: App,
                      callbacks: list[BaseWorkflowCallback] = None) -> None:
         """
         Run workflow
-        :param app_model: App instance
         :param workflow: Workflow instance
         :param triggered_from: triggered from
         :param user: account or end user
@@ -143,13 +143,23 @@
         if not graph:
             raise ValueError('workflow graph not found')
 
+        if 'nodes' not in graph or 'edges' not in graph:
+            raise ValueError('nodes or edges not found in workflow graph')
+
+        if not isinstance(graph.get('nodes'), list):
+            raise ValueError('nodes in workflow graph must be a list')
+
+        if not isinstance(graph.get('edges'), list):
+            raise ValueError('edges in workflow graph must be a list')
+
         # init workflow run
         workflow_run = self._init_workflow_run(
             workflow=workflow,
             triggered_from=triggered_from,
             user=user,
             user_inputs=user_inputs,
-            system_inputs=system_inputs
+            system_inputs=system_inputs,
+            callbacks=callbacks
         )
 
         # init workflow run state
         workflow_run_state = WorkflowRunState(
             workflow_run=workflow_run,
             start_at=time.perf_counter(),
             variable_pool=VariablePool(
                 system_variables=system_inputs,
             )
         )
 
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_run_started(workflow_run)
-
-        # fetch start node
-        start_node = self._get_entry_node(graph)
-        if not start_node:
-            self._workflow_run_failed(
-                workflow_run_state=workflow_run_state,
-                error='Start node not found in workflow graph',
-                callbacks=callbacks
-            )
-            return
+        # fetch predecessor node ids before end node (include: llm, direct answer)
+        streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
 
         try:
             predecessor_node = None
             while True:
-                # run workflow
-                self._run_workflow_node(
-                    workflow_run_state=workflow_run_state,
-                    node=current_node,
+                # get next node, multiple target nodes in the future
+                next_node = self._get_next_node(
+                    graph=graph,
                     predecessor_node=predecessor_node,
                     callbacks=callbacks
                 )
 
-                if current_node.node_type == NodeType.END:
+                if not next_node:
                     break
 
-                # todo fetch next node until end node finished or no next node
-                current_node = None
+                # check if node is streamable
+                if next_node.node_id in streamable_node_ids:
+                    next_node.stream_output_supported = True
 
-                if not current_node:
-                    break
+                # max steps 30 reached
+                if len(workflow_run_state.workflow_node_executions) > 30:
+                    raise ValueError('Max steps 30 reached.')
 
-                predecessor_node = current_node
-                # or max steps 30 reached
                 # or max execution time 10min reached
+                if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
+                    raise ValueError('Max execution time 10min reached.')
+
+                # run workflow, run multiple target nodes in the future
+                self._run_workflow_node(
+                    workflow_run_state=workflow_run_state,
+                    node=next_node,
+                    predecessor_node=predecessor_node,
+                    callbacks=callbacks
+                )
+
+                if next_node.node_type == NodeType.END:
+                    break
+
+                predecessor_node = next_node
+
+            if not predecessor_node and not next_node:
+                self._workflow_run_failed(
+                    workflow_run_state=workflow_run_state,
+                    error='Start node not found in workflow graph.',
+                    callbacks=callbacks
+                )
+                return
         except Exception as e:
             self._workflow_run_failed(
                 workflow_run_state=workflow_run_state,
                 error=str(e),
                 callbacks=callbacks
             )
             return
 
@@ -213,11 +233,40 @@
+    def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
+        """
+        Fetch streamable node ids
+        When the workflow type is chat, only LLM and Direct Answer nodes that sit
+        directly before an END node can stream their output
+        When the workflow type is workflow, only LLM nodes that sit directly before
+        an END node in Plain Text mode can stream their output
+
+        :param workflow: Workflow instance
+        :param graph: workflow graph
+        :return:
+        """
+        workflow_type = WorkflowType.value_of(workflow.type)
+
+        streamable_node_ids = []
+        end_node_ids = []
+        for node_config in graph.get('nodes'):
+            if node_config.get('type') == NodeType.END.value:
+                if workflow_type == WorkflowType.WORKFLOW:
+                    if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
+                        end_node_ids.append(node_config.get('id'))
+                else:
+                    end_node_ids.append(node_config.get('id'))
+
+        for edge_config in graph.get('edges'):
+            if edge_config.get('target') in end_node_ids:
+                streamable_node_ids.append(edge_config.get('source'))
+
+        return streamable_node_ids
+
     def _init_workflow_run(self, workflow: Workflow,
                            triggered_from: WorkflowRunTriggeredFrom,
                            user: Union[Account, EndUser],
                            user_inputs: dict,
-                           system_inputs: Optional[dict] = None) -> WorkflowRun:
+                           system_inputs: Optional[dict] = None,
+                           callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
         """
         Init workflow run
         :param workflow: Workflow instance
         :param triggered_from: triggered from
         :param user: account or end user
         :param user_inputs: user variables inputs
         :param system_inputs: system inputs, like: query, files
+        :param callbacks: workflow callbacks
         :return:
         """
         try:
@@ -260,6 +310,39 @@
             db.session.rollback()
             raise
 
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_run_started(workflow_run)
+
+        return workflow_run
+
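+    # An illustrative minimal graph accepted by run_workflow() above; ids and
+    # titles are made up, and the engine reads only "type", "id" and "data" on
+    # nodes, plus "source", "target" and "source_handle" on edges:
+    #
+    #   {"nodes": [{"id": "start-1", "type": "start", "data": {"title": "Start"}},
+    #              {"id": "llm-1", "type": "llm", "data": {"title": "LLM"}},
+    #              {"id": "end-1", "type": "end", "data": {"title": "End"}}],
+    #    "edges": [{"source": "start-1", "target": "llm-1"},
+    #              {"source": "llm-1", "target": "end-1"}]}
+
+    def _workflow_run_success(self, workflow_run_state: 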
WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run_state: workflow run state + :param callbacks: workflow callbacks + :return: + """ + workflow_run = workflow_run_state.workflow_run + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value + + # fetch last workflow_node_executions + last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] + if last_workflow_node_execution: + workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs) + + workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at + workflow_run.total_tokens = workflow_run_state.total_tokens + workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_finished(workflow_run) + return workflow_run def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, @@ -277,9 +360,8 @@ def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, workflow_run.error = error workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at workflow_run.total_tokens = workflow_run_state.total_tokens - workflow_run.total_price = workflow_run_state.total_price - workflow_run.currency = workflow_run_state.currency workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) + workflow_run.finished_at = datetime.utcnow() db.session.commit() @@ -289,21 +371,77 @@ def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, return workflow_run - def _get_entry_node(self, graph: dict) -> Optional[StartNode]: + def _get_next_node(self, graph: dict, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: """ - Get entry node + Get next node + multiple target nodes in the future. 
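+        Resolution is edge-driven: with no predecessor the Start node is
+        returned; otherwise the predecessor's outgoing edges are scanned and,
+        if the finished node reported an edge_source_handle (its chosen
+        branch), the edge with the matching source_handle is followed.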
:param graph: workflow graph + :param predecessor_node: predecessor node + :param callbacks: workflow callbacks :return: """ nodes = graph.get('nodes') if not nodes: return None - for node_config in nodes.items(): - if node_config.get('type') == NodeType.START.value: - return StartNode(config=node_config) + if not predecessor_node: + for node_config in nodes: + if node_config.get('type') == NodeType.START.value: + return StartNode(config=node_config) + else: + edges = graph.get('edges') + source_node_id = predecessor_node.node_id + + # fetch all outgoing edges from source node + outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] + if not outgoing_edges: + return None + + # fetch target node id from outgoing edges + outgoing_edge = None + source_handle = predecessor_node.node_run_result.edge_source_handle + if source_handle: + for edge in outgoing_edges: + if edge.get('source_handle') and edge.get('source_handle') == source_handle: + outgoing_edge = edge + break + else: + outgoing_edge = outgoing_edges[0] + + if not outgoing_edge: + return None + + target_node_id = outgoing_edge.get('target') + + # fetch target node from target node id + target_node_config = None + for node in nodes: + if node.get('id') == target_node_id: + target_node_config = node + break + + if not target_node_config: + return None - return None + # get next node + target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) + + return target_node( + config=target_node_config, + callbacks=callbacks + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + # TODO check queue is stopped + return time.perf_counter() - start_at > max_execution_time def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node: BaseNode, @@ -320,28 +458,41 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, # add to workflow node executions workflow_run_state.workflow_node_executions.append(workflow_node_execution) - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool, - callbacks=callbacks - ) - except Exception as e: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool + ) + + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed self._workflow_node_execution_failed( workflow_node_execution=workflow_node_execution, - error=str(e), + start_at=start_at, + error=node_run_result.error, callbacks=callbacks ) - raise + raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") # node run success self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, + start_at=start_at, result=node_run_result, callbacks=callbacks ) + for variable_key, variable_value in node_run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + + if node_run_result.metadata.get('total_tokens'): + workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) + return workflow_node_execution def 
_init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState, @@ -384,3 +535,86 @@ def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRun callback.on_workflow_node_execute_started(workflow_node_execution) return workflow_node_execution + + def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + result: NodeRunResult, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + """ + Workflow node execution success + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param result: node run result + :param callbacks: workflow callbacks + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.inputs = json.dumps(result.inputs) + workflow_node_execution.process_data = json.dumps(result.process_data) + workflow_node_execution.outputs = json.dumps(result.outputs) + workflow_node_execution.execution_metadata = json.dumps(result.metadata) + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_finished(workflow_node_execution) + + return workflow_node_execution + + def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + error: str, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param error: error message + :param callbacks: workflow callbacks + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_finished(workflow_node_execution) + + return workflow_node_execution + + def _append_variables_recursively(self, variable_pool: VariablePool, + node_id: str, + variable_key_list: list[str], + variable_value: VariableValue): + """ + Append variables recursively + :param variable_pool: variable pool + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + variable_pool.append_variable( + node_id=node_id, + variable_key_list=variable_key_list, + value=variable_value + ) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + variable_pool=variable_pool, + node_id=node_id, + variable_key_list=new_key_list, + variable_value=value + ) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 85c9c2d2b293f5..572f472f1f083f 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -11,8 +11,6 @@ "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, - "total_price": fields.Float, - "currency": fields.String, "total_steps": fields.Integer, "created_at": TimestampField, "finished_at": TimestampField @@ -29,8 +27,6 @@ "error": 
fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, - "total_price": fields.Float, - "currency": fields.String, "total_steps": fields.Integer, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_at": TimestampField, @@ -56,8 +52,6 @@ "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, - "total_price": fields.Float, - "currency": fields.String, "total_steps": fields.Integer, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), diff --git a/api/models/workflow.py b/api/models/workflow.py index 32ff26196c4378..032134a0d1c3ad 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -216,8 +216,6 @@ class WorkflowRun(db.Model): - error (string) `optional` Error reason - elapsed_time (float) `optional` Time consumption (s) - total_tokens (int) `optional` Total tokens used - - total_price (decimal) `optional` Total cost - - currency (string) `optional` Currency, such as USD / RMB - total_steps (int) Total steps (redundant), default 0 - created_by_role (string) Creator role @@ -251,8 +249,6 @@ class WorkflowRun(db.Model): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - total_price = db.Column(db.Numeric(10, 7)) - currency = db.Column(db.String(255)) total_steps = db.Column(db.Integer, server_default=db.text('0')) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) From a1bc6b50c5488bee749d1111dc979ec69255a447 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 22:10:49 +0800 Subject: [PATCH 075/160] refactor workflow generate pipeline --- api/controllers/console/app/completion.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/web/completion.py | 2 +- api/core/agent/base_agent_runner.py | 2 +- api/core/agent/cot_agent_runner.py | 31 +- api/core/agent/fc_agent_runner.py | 30 +- api/core/app/app_queue_manager.py | 335 -------------- .../app/apps/advanced_chat/app_generator.py | 5 +- api/core/app/apps/advanced_chat/app_runner.py | 19 +- .../advanced_chat/generate_task_pipeline.py | 12 +- api/core/app/apps/agent_chat/app_generator.py | 5 +- api/core/app/apps/agent_chat/app_runner.py | 10 +- api/core/app/apps/base_app_queue_manager.py | 181 ++++++++ api/core/app/apps/base_app_runner.py | 58 ++- api/core/app/apps/chat/app_generator.py | 5 +- api/core/app/apps/chat/app_runner.py | 10 +- api/core/app/apps/completion/app_generator.py | 7 +- api/core/app/apps/completion/app_runner.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 25 +- .../app/apps/message_based_app_generator.py | 2 +- .../apps/message_based_app_queue_manager.py | 29 ++ api/core/app/apps/workflow/app_generator.py | 164 +++++++ .../app/apps/workflow/app_queue_manager.py | 23 + api/core/app/apps/workflow/app_runner.py | 156 +++++++ .../apps/workflow/generate_task_pipeline.py | 408 ++++++++++++++++++ api/core/app/entities/app_invoke_entities.py | 4 +- .../index_tool_callback_handler.py | 8 +- .../workflow_event_trigger_callback.py | 41 +- api/core/moderation/output_moderation.py | 19 +- api/services/workflow_service.py | 21 +- 31 files changed, 1175 insertions(+), 445 deletions(-) delete mode 100644 api/core/app/app_queue_manager.py 
create mode 100644 api/core/app/apps/base_app_queue_manager.py create mode 100644 api/core/app/apps/message_based_app_queue_manager.py create mode 100644 api/core/app/apps/workflow/app_generator.py create mode 100644 api/core/app/apps/workflow/app_queue_manager.py create mode 100644 api/core/app/apps/workflow/app_runner.py create mode 100644 api/core/app/apps/workflow/generate_task_pipeline.py diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fd6cfadfef8575..a7fd0164d86f73 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index dd531974fa10cd..b8a5be0df0768a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,7 +21,7 @@ ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5c488093fa7b1e..410fb5bffd8e4e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -19,7 +19,7 @@ ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 785e2b8d6b9225..ed1378e7e3d85f 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,7 +20,7 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 236a5d9cf77244..0901b7e96598a0 100644 --- a/api/core/agent/base_agent_runner.py +++ 
b/api/core/agent/base_agent_runner.py @@ -6,8 +6,8 @@ from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity -from core.app.app_queue_manager import AppQueueManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 8b444ef3be3f9b..cbb19aca53ad75 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -5,7 +5,8 @@ from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -121,7 +122,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ) if iteration_step > 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt messages prompt_messages = self._organize_cot_prompt_messages( @@ -163,7 +166,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) for chunk in react_chunks: if isinstance(chunk, dict): @@ -225,7 +230,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): llm_usage=usage_dict['usage']) if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) if not scratchpad.action: # failed to extract action, return final answer directly @@ -255,7 +262,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): observation=answer, answer=answer, messages_ids=[]) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) else: # invoke tool error_response = None @@ -282,7 +291,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) message_file_ids = [message_file.id for message_file, _ in message_files] except 
ToolProviderCredentialValidationError as e: @@ -318,7 +329,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=scratchpad.agent_response, messages_ids=message_file_ids, ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool message for prompt_tool in prompt_messages_tools: @@ -352,7 +365,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( @@ -360,7 +373,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ -> Generator[Union[str, dict], None, None]: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 30e5cdd6946a14..7c3849a12ca0c4 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,7 +4,8 @@ from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -135,7 +136,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): @@ -195,7 +198,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not result.message.content: result.message.content = '' - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) yield LLMResultChunk( model=model_instance.model, @@ -233,8 +238,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): messages_ids=[], llm_usage=current_llm_usage ) - - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) final_answer += response + '\n' @@ -275,7 +281,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) # 
publish message file - self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) # add message file ids message_file_ids.append(message_file.id) @@ -331,7 +339,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=None, messages_ids=message_file_ids ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -341,15 +351,15 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=final_answer, + content=final_answer ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py deleted file mode 100644 index 5655c8d979b4f2..00000000000000 --- a/api/core/app/app_queue_manager.py +++ /dev/null @@ -1,335 +0,0 @@ -import queue -import time -from collections.abc import Generator -from enum import Enum -from typing import Any - -from sqlalchemy.orm import DeclarativeMeta - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueAgentMessageEvent, - QueueAgentThoughtEvent, - QueueAnnotationReplyEvent, - QueueErrorEvent, - QueueLLMChunkEvent, - QueueMessage, - QueueMessageEndEvent, - QueueMessageFileEvent, - QueueMessageReplaceEvent, - QueueNodeFinishedEvent, - QueueNodeStartedEvent, - QueuePingEvent, - QueueRetrieverResourcesEvent, - QueueStopEvent, - QueueTextChunkEvent, - QueueWorkflowFinishedEvent, - QueueWorkflowStartedEvent, -) -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from extensions.ext_redis import redis_client -from models.model import MessageAgentThought, MessageFile - - -class PublishFrom(Enum): - APPLICATION_MANAGER = 1 - TASK_PIPELINE = 2 - - -class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: - if not user_id: - raise ValueError("user is required") - - self._task_id = task_id - self._user_id = user_id - self._invoke_from = invoke_from - self._conversation_id = str(conversation_id) - self._app_mode = app_mode - self._message_id = str(message_id) - - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") - - q = queue.Queue() - - self._q = q - - def listen(self) -> Generator: - """ - Listen to queue - :return: - """ - # wait for 10 minutes to stop listen - listen_timeout = 600 - start_time = time.time() - last_ping_time = 0 - - while True: - try: - message = 
self._q.get(timeout=1) - if message is None: - break - - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE - ) - self.stop_listen() - - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 - - def stop_listen(self) -> None: - """ - Stop listen to queue - :return: - """ - self._q.put(None) - - def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish llm chunk to channel - - :param chunk: llm chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueLLMChunkEvent( - chunk=chunk - ), pub_from) - - def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None: - """ - Publish text chunk to channel - - :param text: text - :param pub_from: publish from - :return: - """ - self.publish(QueueTextChunkEvent( - text=text - ), pub_from) - - def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish agent chunk message to channel - - :param chunk: chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentMessageEvent( - chunk=chunk - ), pub_from) - - def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None: - """ - Publish message replace - :param text: text - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageReplaceEvent( - text=text - ), pub_from) - - def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None: - """ - Publish retriever resources - :return: - """ - self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from) - - def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None: - """ - Publish annotation reply - :param message_annotation_id: message annotation id - :param pub_from: publish from - :return: - """ - self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from) - - def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None: - """ - Publish message end - :param llm_result: llm result - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from) - self.stop_listen() - - def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None: - """ - Publish workflow started - :param workflow_run_id: workflow run id - :param pub_from: publish from - :return: - """ - self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from) - - def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None: - """ - Publish workflow finished - :param workflow_run_id: workflow run id - :param pub_from: publish from - :return: - """ - self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from) - - def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: - """ - Publish node started - :param workflow_node_execution_id: workflow node execution id - :param pub_from: publish from - :return: - """ - 
self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) - - def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: - """ - Publish node finished - :param workflow_node_execution_id: workflow node execution id - :param pub_from: publish from - :return: - """ - self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) - - def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_agent_thought: message agent thought - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentThoughtEvent( - agent_thought_id=message_agent_thought.id - ), pub_from) - - def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_file: message file - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), pub_from) - - def publish_error(self, e, pub_from: PublishFrom) -> None: - """ - Publish error - :param e: error - :param pub_from: publish from - :return: - """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) - self.stop_listen() - - def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: - """ - Publish event to queue - :param event: - :param pub_from: - :return: - """ - self._check_for_sqlalchemy_models(event.dict()) - - message = QueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event - ) - - self._q.put(message) - - if isinstance(event, QueueStopEvent): - self.stop_listen() - - if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise ConversationTaskStoppedException() - - @classmethod - def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: - """ - Set task stop flag - :return: - """ - result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) - if result is None: - return - - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": - return - - stopped_cache_key = cls._generate_stopped_cache_key(task_id) - redis_client.setex(stopped_cache_key, 600, 1) - - def _is_stopped(self) -> bool: - """ - Check if task is stopped - :return: - """ - stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) - result = redis_client.get(stopped_cache_key) - if result is not None: - return True - - return False - - @classmethod - def _generate_task_belong_cache_key(cls, task_id: str) -> str: - """ - Generate task belong cache key - :param task_id: task id - :return: - """ - return f"generate_task_belong:{task_id}" - - @classmethod - def _generate_stopped_cache_key(cls, task_id: str) -> str: - """ - Generate stopped cache key - :param task_id: task id - :return: - """ - return f"generate_task_stopped:{task_id}" - - def _check_for_sqlalchemy_models(self, data: Any): - # from entity to dict or list - if isinstance(data, dict): - for key, value in data.items(): - self._check_for_sqlalchemy_models(value) - elif isinstance(data, list): - for item in data: - self._check_for_sqlalchemy_models(item) - else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model 
instances " - "that cause thread safety issues is not allowed.") - - -class ConversationTaskStoppedException(Exception): - pass diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 937f95679a43a1..a19a5c8f6763f2 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -8,11 +8,12 @@ from pydantic import ValidationError from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -101,7 +102,7 @@ def generate(self, app_model: App, ) = self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c5ffa8016567ac..8fff8fc37ea0d3 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -2,14 +2,14 @@ import time from typing import cast -from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueStopEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable @@ -93,7 +93,7 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)], + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) def handle_input_moderation(self, queue_manager: AppQueueManager, @@ -153,9 +153,9 @@ def handle_annotation_reply(self, app_record: App, ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + 
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) self._stream_output( @@ -182,7 +182,11 @@ def _stream_output(self, queue_manager: AppQueueManager, if stream: index = 0 for token in text: - queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) index += 1 time.sleep(0.01) @@ -190,4 +194,3 @@ def _stream_output(self, queue_manager: AppQueueManager, QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER ) - queue_manager.stop_listen() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index cfeb46f05a2b80..84352f16c7505f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -46,6 +46,7 @@ class TaskState(BaseModel): """ answer: str = "" metadata: dict = {} + usage: LLMUsage class AdvancedChatAppGenerateTaskPipeline: @@ -349,7 +350,12 @@ def _process_stream_response(self) -> Generator: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + self._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE @@ -558,5 +564,5 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), - on_message_replace_func=self._queue_manager.publish_message_replace + queue_manager=self._queue_manager ) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index d5dbdf0dd2cb05..6d27620a0986fb 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -9,10 +9,11 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -119,7 +120,7 @@ def generate(self, app_model: App, ) = 
self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 27a473fb17e3d0..2e142c63f1ff47 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -4,10 +4,11 @@ from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -120,10 +121,11 @@ def run(self, application_generate_entity: AgentChatAppGenerateEntity, ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py new file mode 100644 index 00000000000000..0391599040c97d --- /dev/null +++ b/api/core/app/apps/base_app_queue_manager.py @@ -0,0 +1,181 @@ +import queue +import time +from abc import abstractmethod +from collections.abc import Generator +from enum import Enum +from typing import Any + +from sqlalchemy.orm import DeclarativeMeta + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessage, + QueueMessageEndEvent, + QueuePingEvent, + QueueStopEvent, +) +from extensions.ext_redis import redis_client + + +class PublishFrom(Enum): + APPLICATION_MANAGER = 1 + TASK_PIPELINE = 2 + + +class AppQueueManager: + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom) -> None: + if not user_id: + raise ValueError("user is required") + + self._task_id = task_id + self._user_id = user_id + self._invoke_from = invoke_from + + user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + + q = queue.Queue() + + self._q = q + + def listen(self) -> Generator: + """ + Listen to queue + :return: + """ + # wait for 10 minutes to stop listen + listen_timeout = 600 + start_time = time.time() + last_ping_time = 0 + + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break + + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if 
elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), + PublishFrom.TASK_PIPELINE + ) + + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + + def stop_listen(self) -> None: + """ + Stop listen to queue + :return: + """ + self._q.put(None) + + def publish_error(self, e, pub_from: PublishFrom) -> None: + """ + Publish error + :param e: error + :param pub_from: publish from + :return: + """ + self.publish(QueueErrorEvent( + error=e + ), pub_from) + + def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + self._check_for_sqlalchemy_models(event.dict()) + + message = self.construct_queue_message(event) + + self._q.put(message) + + if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise ConversationTaskStoppedException() + + @abstractmethod + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + raise NotImplementedError + + @classmethod + def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: + """ + Set task stop flag + :return: + """ + result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) + if result is None: + return + + user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' + if result.decode('utf-8') != f"{user_prefix}-{user_id}": + return + + stopped_cache_key = cls._generate_stopped_cache_key(task_id) + redis_client.setex(stopped_cache_key, 600, 1) + + def _is_stopped(self) -> bool: + """ + Check if task is stopped + :return: + """ + stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) + result = redis_client.get(stopped_cache_key) + if result is not None: + return True + + return False + + @classmethod + def _generate_task_belong_cache_key(cls, task_id: str) -> str: + """ + Generate task belong cache key + :param task_id: task id + :return: + """ + return f"generate_task_belong:{task_id}" + + @classmethod + def _generate_stopped_cache_key(cls, task_id: str) -> str: + """ + Generate stopped cache key + :param task_id: task id + :return: + """ + return f"generate_task_stopped:{task_id}" + + def _check_for_sqlalchemy_models(self, data: Any): + # from entity to dict or list + if isinstance(data, dict): + for key, value in data.items(): + self._check_for_sqlalchemy_models(value) + elif isinstance(data, list): + for item in data: + self._check_for_sqlalchemy_models(item) + else: + if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): + raise TypeError("Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed.") + + +class ConversationTaskStoppedException(Exception): + pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index dda240d7789d65..e7ce7f25ef51a4 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -3,13 +3,14 @@ from typing import Optional, Union, cast from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity -from core.app.app_queue_manager 
import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AppGenerateEntity, EasyUIBasedAppGenerateEntity, InvokeFrom, ModelConfigWithCredentialsEntity, ) +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch @@ -187,25 +188,32 @@ def direct_output(self, queue_manager: AppQueueManager, if stream: index = 0 for token in text: - queue_manager.publish_llm_chunk(LLMResultChunk( + chunk = LLMResultChunk( model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage(content=token) ) - ), PublishFrom.APPLICATION_MANAGER) + ) + + queue_manager.publish( + QueueLLMChunkEvent( + chunk=chunk + ), PublishFrom.APPLICATION_MANAGER + ) index += 1 time.sleep(0.01) - queue_manager.publish_message_end( - llm_result=LLMResult( - model=app_generate_entity.model_config.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() - ), - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=app_generate_entity.model_config.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=text), + usage=usage if usage else LLMUsage.empty_usage() + ), + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], @@ -241,9 +249,10 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, :param queue_manager: application queue manager :return: """ - queue_manager.publish_message_end( - llm_result=invoke_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=invoke_result, + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result_stream(self, invoke_result: Generator, @@ -261,9 +270,17 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage = None for result in invoke_result: if not agent: - queue_manager.publish_llm_chunk(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueLLMChunkEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) else: - queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueAgentMessageEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) text += result.delta.message.content @@ -286,9 +303,10 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage=usage ) - queue_manager.publish_message_end( - llm_result=llm_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=llm_result, + ), PublishFrom.APPLICATION_MANAGER ) def moderation_for_inputs(self, app_id: str, @@ -311,7 +329,7 @@ def moderation_for_inputs(self, app_id: str, tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query, + query=query if query else '' ) def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, diff --git a/api/core/app/apps/chat/app_generator.py 
b/api/core/app/apps/chat/app_generator.py index 978ac9656b64a5..7ddf8dfe32b71a 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -9,10 +9,11 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -119,7 +120,7 @@ def generate(self, app_model: App, ) = self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index bce4606f21ba26..d51f3db5409eec 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,12 +1,13 @@ import logging from typing import cast -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( ChatAppGenerateEntity, ) +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -117,10 +118,11 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 9355bae12381f8..7150bee3cefcfe 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -9,10 +9,11 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from 
core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -112,7 +113,7 @@ def generate(self, app_model: App, ) = self._init_generate_records(application_generate_entity) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, @@ -263,7 +264,7 @@ def generate_more_like_this(self, app_model: App, ) = self._init_generate_records(application_generate_entity) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index d67d485e1d7d89..04adf77be5b2c4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,7 +1,7 @@ import logging from typing import cast -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 80596668b8680f..856bfb623d0e22 100644 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ChatAppGenerateEntity, @@ -385,14 +385,19 @@ def _process_stream_response(self) -> Generator: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_llm_chunk(LLMResultChunk( - model=self._task_state.llm_result.model, - prompt_messages=self._task_state.llm_result.prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) - ), PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueLLMChunkEvent( + chunk=LLMResultChunk( + model=self._task_state.llm_result.model, + prompt_messages=self._task_state.llm_result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) + ) + ) + ), PublishFrom.TASK_PIPELINE + ) + self._queue_manager.publish( 
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE @@ -664,5 +669,5 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), - on_message_replace_func=self._queue_manager.publish_message_replace + queue_manager=self._queue_manager ) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index dab72bd6d6cf7c..3dee68b5e1db59 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -6,8 +6,8 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py new file mode 100644 index 00000000000000..ed9475502d465f --- /dev/null +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -0,0 +1,29 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueMessage, +) + + +class MessageBasedAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + conversation_id: str, + app_mode: str, + message_id: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._conversation_id = str(conversation_id) + self._app_mode = app_mode + self._message_id = str(message_id) + + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + return QueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py new file mode 100644 index 00000000000000..891ca4c2bedcee --- /dev/null +++ b/api/core/app/apps/workflow/app_generator.py @@ -0,0 +1,164 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db 
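The MessageBasedAppQueueManager above captures the core of this refactor: AppQueueManager keeps the queue, the listen loop, and the Redis stop-flag bookkeeping, and delegates only the shape of the QueueMessage envelope to its subclasses. A minimal sketch of that template-method split, using simplified stand-in types rather than the real QueueMessage and AppQueueEvent models:

from abc import ABC, abstractmethod
import queue


class SimpleQueueManager(ABC):
    """Base class: owns the queue and the publish flow."""

    def __init__(self, task_id: str) -> None:
        self._task_id = task_id
        self._q: queue.Queue = queue.Queue()

    def publish(self, event: dict) -> None:
        # the base class never needs to know which app mode it serves
        self._q.put(self.construct_queue_message(event))

    @abstractmethod
    def construct_queue_message(self, event: dict) -> dict:
        raise NotImplementedError


class SimpleMessageQueueManager(SimpleQueueManager):
    """Chat-style apps route every event to a conversation and message."""

    def __init__(self, task_id: str, conversation_id: str, message_id: str) -> None:
        super().__init__(task_id)
        self._conversation_id = conversation_id
        self._message_id = message_id

    def construct_queue_message(self, event: dict) -> dict:
        return {
            'task_id': self._task_id,
            'conversation_id': self._conversation_id,
            'message_id': self._message_id,
            'event': event,
        }

The WorkflowAppQueueManager defined later in this patch follows the same shape but drops the conversation and message ids, since a bare workflow run has neither.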
+from models.account import Account +from models.model import App, EndUser +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowAppGenerator(BaseAppGenerator): + def generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + inputs = args['inputs'] + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs=self._get_cleaned_inputs(inputs, app_config), + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from + ) + + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Generate worker in a new thread. 
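+
+        This keeps the request thread free to stream the response: the worker
+        re-enters the Flask app context and reports progress only through the
+        shared queue manager, never by touching the response directly.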
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :return: + """ + with flask_app.app_context(): + try: + # workflow app + runner = WorkflowAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py new file mode 100644 index 00000000000000..0f9b0a1c78722e --- /dev/null +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -0,0 +1,23 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueMessage, +) + + +class WorkflowAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + return QueueMessage( + task_id=self._task_id, + app_mode=self._app_mode, + event=event + ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py new file mode 100644 index 00000000000000..e675026e41995c --- /dev/null +++ b/api/core/app/apps/workflow/app_runner.py @@ -0,0 +1,156 @@ +import logging +import time +from typing import cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent +from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.moderation.base import ModerationException +from core.moderation.input_moderation import InputModeration +from 
core.workflow.entities.node_entities import SystemVariable +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class WorkflowAppRunner: + """ + Workflow Application Runner + """ + + def run(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(WorkflowAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + inputs = application_generate_entity.inputs + files = application_generate_entity.files + + # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs + ): + return + + # fetch user + if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: + user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() + else: + user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + + # RUN WORKFLOW + workflow_engine_manager = WorkflowEngineManager() + workflow_engine_manager.run_workflow( + workflow=workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, + user=user, + user_inputs=inputs, + system_inputs={ + SystemVariable.FILES: files + }, + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + ) + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: WorkflowAppGenerateEntity, + inputs: dict) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :return: + """ + try: + # process sensitive_word_avoidance + moderation_feature = InputModeration() + _, inputs, query = moderation_feature.check( + app_id=app_record.id, + tenant_id=app_generate_entity.app_config.tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) + except ModerationException as e: + if app_generate_entity.stream: + self._stream_output( + queue_manager=queue_manager, + text=str(e), + ) + + queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION), + PublishFrom.APPLICATION_MANAGER + ) + return True + + return False + + def _stream_output(self, queue_manager: AppQueueManager, + text: str) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :return: + """ + index = 0 + for token in text: + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) + index += 1 + time.sleep(0.01) + + def moderation_for_inputs(self, app_id: str, + tenant_id: 
str, + app_generate_entity: AppGenerateEntity, + inputs: dict) -> tuple[bool, dict, str]: + """ + Process sensitive_word_avoidance. + :param app_id: app id + :param tenant_id: tenant id + :param app_generate_entity: app generate entity + :param inputs: inputs + :return: + """ + moderation_feature = InputModeration() + return moderation_feature.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py new file mode 100644 index 00000000000000..df83ad634eecc8 --- /dev/null +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -0,0 +1,408 @@ +import json +import logging +import time +from collections.abc import Generator +from typing import Optional, Union + +from pydantic import BaseModel + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageReplaceEvent, + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.moderation.output_moderation import ModerationRule, OutputModeration +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus + +logger = logging.getLogger(__name__) + + +class TaskState(BaseModel): + """ + TaskState entity + """ + answer: str = "" + metadata: dict = {} + workflow_run_id: Optional[str] = None + + +class WorkflowAppGenerateTaskPipeline: + """ + WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + """ + self._application_generate_entity = application_generate_entity + self._queue_manager = queue_manager + self._task_state = TaskState() + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + self._stream = stream + + def process(self) -> Union[dict, Generator]: + """ + Process generate task pipeline. + :return: + """ + if self._stream: + return self._process_stream_response() + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> dict: + """ + Process blocking response. 
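+
+        Drains the queue until an error, stop, or workflow-finished event
+        arrives, then returns a single 'workflow_finished' payload instead of
+        streaming SSE frames.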
+ :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + raise self._handle_error(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: + workflow_run = self._get_workflow_run(event.workflow_run_id) + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs + self._task_state.answer = outputs.get('text', '') + else: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + return response + else: + continue + + def _process_stream_response(self) -> Generator: + """ + Process stream response. + :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + data = self._error_to_stream_response_data(self._handle_error(event)) + yield self._yield_response(data) + break + elif isinstance(event, QueueWorkflowStartedEvent): + self._task_state.workflow_run_id = event.workflow_run_id + + workflow_run = self._get_workflow_run(event.workflow_run_id) + response = { + 'event': 'workflow_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'created_at': int(workflow_run.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeStartedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + response = { + 'event': 'node_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeFinishedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + response = { + 'event': 'node_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 
'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'process_data': workflow_node_execution.process_data_dict, + 'outputs': workflow_node_execution.outputs_dict, + 'status': workflow_node_execution.status, + 'error': workflow_node_execution.error, + 'elapsed_time': workflow_node_execution.elapsed_time, + 'execution_metadata': workflow_node_execution.execution_metadata_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()), + 'finished_at': int(workflow_node_execution.finished_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: + workflow_run = self._get_workflow_run(event.workflow_run_id) + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs + self._task_state.answer = outputs.get('text', '') + else: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + self._output_moderation_handler = None + + replace_response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': self._task_state.answer + } + } + + yield self._yield_response(replace_response) + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + yield self._yield_response(workflow_run_response) + elif isinstance(event, QueueTextChunkEvent): + delta_text = event.chunk_text + if delta_text is None: + continue + + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + continue + else: + self._output_moderation_handler.append_new_token(delta_text) + + self._task_state.answer += delta_text + response = self._handle_chunk(delta_text) + yield self._yield_response(response) + elif isinstance(event, QueueMessageReplaceEvent): + response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': event.text + } + } + 
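+                # a text_replace frame asks the client to discard everything
+                # rendered so far and show this replacement text instead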
+ yield self._yield_response(response) + elif isinstance(event, QueuePingEvent): + yield "event: ping\n\n" + else: + continue + + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Get workflow run. + :param workflow_run_id: workflow run id + :return: + """ + return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + + def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: + """ + Get workflow node execution. + :param workflow_node_execution_id: workflow node execution id + :return: + """ + return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + + def _handle_chunk(self, text: str) -> dict: + """ + Handle completed event. + :param text: text + :return: + """ + response = { + 'event': 'text_chunk', + 'workflow_run_id': self._task_state.workflow_run_id, + 'task_id': self._application_generate_entity.task_id, + 'data': { + 'text': text + } + } + + return response + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response_data(self, e: Exception) -> dict: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return { + 'event': 'error', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + **data + } + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. 
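+
+        OutputModeration is now handed the queue manager itself rather than
+        an on_message_replace_func callback, so it can publish
+        QueueMessageReplaceEvent frames on its own.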
+ :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + queue_manager=self._queue_manager + ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1c4f32b8f28da1..01cbd7d2b2df47 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -127,9 +127,9 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): query: Optional[str] = None -class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity): +class WorkflowAppGenerateEntity(AppGenerateEntity): """ - Workflow UI Based Application Generate Entity. + Workflow Application Generate Entity. """ # app config app_config: WorkflowUIBasedAppConfig diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index ca781a55bc9c90..8e1f496b226c14 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,7 @@ -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment @@ -82,4 +83,7 @@ def return_retriever_resource_info(self, resource: list): db.session.add(dataset_retriever_resource) db.session.commit() - self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER) + self._queue_manager.publish( + QueueRetrieverResourcesEvent(retriever_resources=resource), + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index 80dabc75489738..f8bad94252444f 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,4 +1,11 @@ -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import ( + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -12,43 +19,45 @@ def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ Workflow run started """ - self._queue_manager.publish_workflow_started( - workflow_run_id=workflow_run.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: """ Workflow run finished """ - self._queue_manager.publish_workflow_finished( - workflow_run_id=workflow_run.id, - pub_from=PublishFrom.TASK_PIPELINE + 
self._queue_manager.publish( + QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: """ Workflow node execute started """ - self._queue_manager.publish_node_started( - workflow_node_execution_id=workflow_node_execution.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: """ Workflow node execute finished """ - self._queue_manager.publish_node_finished( - workflow_node_execution_id=workflow_node_execution.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER ) + def on_text_chunk(self, text: str) -> None: """ Publish text chunk """ - self._queue_manager.publish_text_chunk( - text=text, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 749ee431e8f319..af8910614da0cd 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,8 @@ from flask import Flask, current_app from pydantic import BaseModel -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -25,7 +26,7 @@ class OutputModeration(BaseModel): app_id: str rule: ModerationRule - on_message_replace_func: Any + queue_manager: AppQueueManager thread: Optional[threading.Thread] = None thread_running: bool = True @@ -67,7 +68,12 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> final_output = result.text if public_event: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) return final_output @@ -117,7 +123,12 @@ def worker(self, flask_app: Flask, buffer_size: int): # trigger replace event if self.thread_running: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) if result.action == ModerationAction.DIRECT_OUTPUT: break diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2c1b6eb819cb38..144d136bdc7f01 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,6 +6,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import 
WorkflowEngineManager @@ -175,8 +176,24 @@ def run_draft_workflow(self, app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom) -> Union[dict, Generator]: - # TODO - pass + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = WorkflowAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response def convert_to_workflow(self, app_model: App, account: Account) -> App: """ From 079cc082a36252b841735952530aace430ec6ff1 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 09:55:29 +0800 Subject: [PATCH 076/160] use callback to filter workflow stream output --- api/core/app/apps/advanced_chat/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 41 +++++++-- api/core/app/apps/workflow/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 87 +++++++++++++++++++ .../callbacks/base_workflow_callback.py | 6 +- api/core/workflow/nodes/base_node.py | 11 +-- api/core/workflow/workflow_engine_manager.py | 36 -------- 7 files changed, 138 insertions(+), 57 deletions(-) rename api/core/{callback_handler => app/apps/advanced_chat}/workflow_event_trigger_callback.py (55%) create mode 100644 api/core/app/apps/workflow/workflow_event_trigger_callback.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 8fff8fc37ea0d3..077f0c2de0ed2a 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,7 @@ from typing import cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( @@ -10,7 +11,6 @@ InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent -from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager @@ -93,7 +93,10 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py similarity index 55% rename from api/core/callback_handler/workflow_event_trigger_callback.py rename to api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index f8bad94252444f..44fb5905b0fad9 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -7,13 +7,15 @@ QueueWorkflowStartedEvent, ) from 
core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
-from models.workflow import WorkflowNodeExecution, WorkflowRun
+from core.workflow.entities.node_entities import NodeType
+from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
 
 
 class WorkflowEventTriggerCallback(BaseWorkflowCallback):
 
-    def __init__(self, queue_manager: AppQueueManager):
+    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
+        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
 
     def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
         """
@@ -51,13 +53,34 @@ def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNod
             PublishFrom.APPLICATION_MANAGER
         )
 
-
-    def on_text_chunk(self, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str) -> None:
         """
         Publish text chunk
         """
-        self._queue_manager.publish(
-            QueueTextChunkEvent(
-                text=text
-            ), PublishFrom.APPLICATION_MANAGER
-        )
+        if node_id in self._streamable_node_ids:
+            self._queue_manager.publish(
+                QueueTextChunkEvent(
+                    text=text
+                ), PublishFrom.APPLICATION_MANAGER
+            )
+
+    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
+        """
+        Fetch streamable node ids
+        When the workflow type is chat, only LLM and Direct Answer nodes directly preceding the END node can stream output
+        When the workflow type is workflow, only LLM nodes directly preceding an END node in Plain Text mode can stream output
+
+        :param graph: workflow graph
+        :return:
+        """
+        streamable_node_ids = []
+        end_node_ids = []
+        for node_config in graph.get('nodes'):
+            if node_config.get('type') == NodeType.END.value:
+                end_node_ids.append(node_config.get('id'))
+
+        for edge_config in graph.get('edges'):
+            if edge_config.get('target') in end_node_ids:
+                streamable_node_ids.append(edge_config.get('source'))
+
+        return streamable_node_ids
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index e675026e41995c..132282ffe3f6e3 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -4,13 +4,13 @@
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
+from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.app.entities.app_invoke_entities import (
     AppGenerateEntity,
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
 from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
-from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.moderation.base import ModerationException
 from core.moderation.input_moderation import InputModeration
 from core.workflow.entities.node_entities import SystemVariable
@@ -76,7 +76,10 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity,
             system_inputs={
                 SystemVariable.FILES: files
             },
-            callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
+            callbacks=[WorkflowEventTriggerCallback(
+                queue_manager=queue_manager,
+                workflow=workflow
+            )]
         )
 
     def handle_input_moderation(self, queue_manager: AppQueueManager,
diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
new file mode 100644
index 00000000000000..57775f2ccebd83
--- /dev/null
+++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
@@ -0,0 +1,87 @@
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.queue_entities import (
+    QueueNodeFinishedEvent,
+    QueueNodeStartedEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFinishedEvent,
+    QueueWorkflowStartedEvent,
+)
+from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.entities.node_entities import NodeType
+from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
+
+
+class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+
+    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
+        self._queue_manager = queue_manager
+        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
+
+    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run started
+        """
+        self._queue_manager.publish(
+            QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run finished
+        """
+        self._queue_manager.publish(
+            QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute started
+        """
+        self._queue_manager.publish(
+            QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute finished
+        """
+        self._queue_manager.publish(
+            QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+        """
+        Publish text chunk
+        """
+        if node_id in self._streamable_node_ids:
+            self._queue_manager.publish(
+                QueueTextChunkEvent(
+                    text=text
+                ), PublishFrom.APPLICATION_MANAGER
+            )
+
+    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
+        """
+        Fetch streamable node ids
+        When the workflow type is chat, only LLM and Direct Answer nodes directly preceding the END node can stream output
+        When the workflow type is workflow, only LLM nodes directly preceding an END node in Plain Text mode can stream output
+
+        :param graph: workflow graph
+        :return:
+        """
+        streamable_node_ids = []
+        end_node_ids = []
+        for node_config in graph.get('nodes'):
+            if node_config.get('type') == NodeType.END.value:
+                if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
+                    end_node_ids.append(node_config.get('id'))
+
+        for edge_config in graph.get('edges'):
+            if edge_config.get('target') in end_node_ids:
+                streamable_node_ids.append(edge_config.get('source'))
+
+        return streamable_node_ids
diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py
index 3425b2b03c7111..3866bf2c1518eb 100644
--- a/api/core/workflow/callbacks/base_workflow_callback.py
+++ b/api/core/workflow/callbacks/base_workflow_callback.py
@@ -1,9 +1,9 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 
 from models.workflow import WorkflowNodeExecution, WorkflowRun
 
 
-class BaseWorkflowCallback:
+class BaseWorkflowCallback(ABC):
     @abstractmethod
     def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
         """
@@ -33,7 +33,7 @@ def
on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNod raise NotImplementedError @abstractmethod - def on_text_chunk(self, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index efffdfae1aeeba..1ff05f9f4e7d28 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -16,7 +16,6 @@ class BaseNode: node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None - stream_output_supported: bool = False callbacks: list[BaseWorkflowCallback] def __init__(self, config: dict, @@ -71,10 +70,12 @@ def publish_text_chunk(self, text: str) -> None: :param text: chunk text :return: """ - if self.stream_output_supported: - if self.callbacks: - for callback in self.callbacks: - callback.on_text_chunk(text) + if self.callbacks: + for callback in self.callbacks: + callback.on_node_text_chunk( + node_id=self.node_id, + text=text + ) @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 908b6849309105..4d881d3d045d24 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -32,7 +32,6 @@ WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, - WorkflowType, ) node_classes = { @@ -171,9 +170,6 @@ def run_workflow(self, workflow: Workflow, ) ) - # fetch predecessor node ids before end node (include: llm, direct answer) - streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph) - try: predecessor_node = None while True: @@ -187,10 +183,6 @@ def run_workflow(self, workflow: Workflow, if not next_node: break - # check if node is streamable - if next_node.node_id in streamable_node_ids: - next_node.stream_output_supported = True - # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -233,34 +225,6 @@ def run_workflow(self, workflow: Workflow, callbacks=callbacks ) - def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param workflow: Workflow instance - :param graph: workflow graph - :return: - """ - workflow_type = WorkflowType.value_of(workflow.type) - - streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: - if workflow_type == WorkflowType.WORKFLOW: - if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': - end_node_ids.append(node_config.get('id')) - else: - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids - def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], From 3e54cb26beee1c23c31c8eaa2f01ef32a9e8f471 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 10:09:23 +0800 Subject: [PATCH 077/160] move funcs --- api/core/workflow/workflow_engine_manager.py | 25 
-------------------- api/services/workflow_service.py | 14 +++++++---- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 4d881d3d045d24..8ab0eb4802c913 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -51,30 +51,6 @@ class WorkflowEngineManager: - def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: - """ - Get draft workflow - """ - # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() - - # return draft workflow - return workflow - - def get_published_workflow(self, app_model: App) -> Optional[Workflow]: - """ - Get published workflow - """ - if not app_model.workflow_id: - return None - - # fetch published workflow by workflow_id - return self.get_workflow(app_model, app_model.workflow_id) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ Get workflow @@ -404,7 +380,6 @@ def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: :param max_execution_time: max execution time :return: """ - # TODO check queue is stopped return time.perf_counter() - start_at > max_execution_time def _run_workflow_node(self, workflow_run_state: WorkflowRunState, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 144d136bdc7f01..833c22cdffaa01 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -26,22 +26,28 @@ def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow """ - workflow_engine_manager = WorkflowEngineManager() + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() # return draft workflow - return workflow_engine_manager.get_draft_workflow(app_model=app_model) + return workflow def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow """ + if not app_model.workflow_id: return None workflow_engine_manager = WorkflowEngineManager() - # return published workflow - return workflow_engine_manager.get_published_workflow(app_model=app_model) + # fetch published workflow by workflow_id + return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id) def sync_draft_workflow(self, app_model: App, graph: dict, From 8684b172d201ef9414a6dff756f42f5439f809f0 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 15:43:55 +0800 Subject: [PATCH 078/160] add start, end, direct answer node --- .../entities/base_node_data_entities.py | 2 - api/core/workflow/entities/node_entities.py | 13 ++++- .../workflow/entities/variable_entities.py | 9 +++ .../workflow/entities/workflow_entities.py | 7 ++- api/core/workflow/nodes/base_node.py | 4 +- .../nodes/direct_answer/direct_answer_node.py | 51 ++++++++++++++++- .../workflow/nodes/direct_answer/entities.py | 10 ++++ api/core/workflow/nodes/end/end_node.py | 57 ++++++++++++++++++- api/core/workflow/nodes/end/entities.py | 43 ++++++++++++++ api/core/workflow/nodes/llm/entities.py | 8 +++ api/core/workflow/nodes/llm/llm_node.py | 21 ++++++- api/core/workflow/nodes/start/entities.py | 16 +----- api/core/workflow/nodes/start/start_node.py | 56 ++++++++++++++++-- 
api/core/workflow/workflow_engine_manager.py | 8 ++- 14 files changed, 274 insertions(+), 31 deletions(-) create mode 100644 api/core/workflow/entities/variable_entities.py create mode 100644 api/core/workflow/nodes/direct_answer/entities.py create mode 100644 api/core/workflow/nodes/llm/entities.py diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index afa6ddff047a2e..fc6ee231ffc7ea 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -5,7 +5,5 @@ class BaseNodeData(ABC, BaseModel): - type: str - title: str desc: Optional[str] = None diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index af539692ef8eaf..263172da31b88e 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel @@ -46,6 +46,15 @@ class SystemVariable(Enum): CONVERSATION = 'conversation' +class NodeRunMetadataKey(Enum): + """ + Node Run Metadata Key. + """ + TOTAL_TOKENS = 'total_tokens' + TOTAL_PRICE = 'total_price' + CURRENCY = 'currency' + + class NodeRunResult(BaseModel): """ Node Run Result. @@ -55,7 +64,7 @@ class NodeRunResult(BaseModel): inputs: Optional[dict] = None # node inputs process_data: Optional[dict] = None # process data outputs: Optional[dict] = None # node outputs - metadata: Optional[dict] = None # node metadata + metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py new file mode 100644 index 00000000000000..19d9af2a6171a4 --- /dev/null +++ b/api/core/workflow/entities/variable_entities.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class VariableSelector(BaseModel): + """ + Variable Selector. 
+    """
+    variable: str
+    value_selector: list[str]
diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py
index 0d78e4c4f1fff5..8c15cb95cdced3 100644
--- a/api/core/workflow/entities/workflow_entities.py
+++ b/api/core/workflow/entities/workflow_entities.py
@@ -5,13 +5,18 @@ class WorkflowRunState:
     workflow_run: WorkflowRun
     start_at: float
+    user_inputs: dict
     variable_pool: VariablePool
 
     total_tokens: int = 0
 
    workflow_node_executions: list[WorkflowNodeExecution] = []
 
-    def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
+    def __init__(self, workflow_run: WorkflowRun,
+                 start_at: float,
+                 user_inputs: dict,
+                 variable_pool: VariablePool) -> None:
         self.workflow_run = workflow_run
         self.start_at = start_at
+        self.user_inputs = user_inputs
         self.variable_pool = variable_pool
diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py
index 1ff05f9f4e7d28..6720017d9f0814 100644
--- a/api/core/workflow/nodes/base_node.py
+++ b/api/core/workflow/nodes/base_node.py
@@ -1,4 +1,4 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import Optional
 
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
@@ -8,7 +8,7 @@ from models.workflow import WorkflowNodeExecutionStatus
 
 
-class BaseNode:
+class BaseNode(ABC):
     _node_data_cls: type[BaseNodeData]
     _node_type: NodeType
diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py
index c6013974b810da..80ecdf77571ecc 100644
--- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py
+++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py
@@ -1,5 +1,54 @@
+import time
+from typing import Optional, cast
+
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 class DirectAnswerNode(BaseNode):
-    pass
+    _node_data_cls = DirectAnswerNodeData
+    node_type = NodeType.DIRECT_ANSWER
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+
+        if variable_pool is None and run_args:
+            raise ValueError("Single step debugging is not supported.")
+
+        variable_values = {}
+        for variable_selector in node_data.variables:
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector,
+                target_value_type=ValueType.STRING
+            )
+
+            variable_values[variable_selector.variable] = value
+
+        # format answer template
+        template_parser = PromptTemplateParser(node_data.answer)
+        answer = template_parser.format(variable_values)
+
+        # publish answer as stream
+        for word in answer:
+            self.publish_text_chunk(word)
+            time.sleep(0.01)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variable_values,
+            output={
+                "answer": answer
+            }
+        )
diff --git a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/direct_answer/entities.py
new file mode 100644
index 00000000000000..e7c11e3c4d1d2e
--- /dev/null
+++ b/api/core/workflow/nodes/direct_answer/entities.py
@@ -0,0 +1,10 @@
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class DirectAnswerNodeData(BaseNodeData):
+    """
+    DirectAnswer Node Data.
+    """
+    variables: list[VariableSelector] = []
+    answer: str
diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py
index f9aea89af7cb3a..62429e3ac284fe 100644
--- a/api/core/workflow/nodes/end/end_node.py
+++ b/api/core/workflow/nodes/end/end_node.py
@@ -1,5 +1,60 @@
+from typing import Optional, cast
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 class EndNode(BaseNode):
-    pass
+    _node_data_cls = EndNodeData
+    node_type = NodeType.END
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+        outputs_config = node_data.outputs
+
+        if variable_pool is not None:
+            outputs = None
+            if outputs_config:
+                if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
+                    plain_text_selector = outputs_config.plain_text_selector
+                    if plain_text_selector:
+                        outputs = {
+                            'text': variable_pool.get_variable_value(
+                                variable_selector=plain_text_selector,
+                                target_value_type=ValueType.STRING
+                            )
+                        }
+                    else:
+                        outputs = {
+                            'text': ''
+                        }
+                elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
+                    structured_variables = outputs_config.structured_variables
+                    if structured_variables:
+                        outputs = {}
+                        for variable_selector in structured_variables:
+                            variable_value = variable_pool.get_variable_value(
+                                variable_selector=variable_selector.value_selector
+                            )
+
+                            outputs[variable_selector.variable] = variable_value
+                    else:
+                        outputs = {}
+        else:
+            raise ValueError("Single step debugging is not supported.")
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=outputs,
+            outputs=outputs
+        )
diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py
index 045e7effc49711..32212ae7faccbc 100644
--- a/api/core/workflow/nodes/end/entities.py
+++ b/api/core/workflow/nodes/end/entities.py
@@ -1,4 +1,10 @@
 from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
 
 
 class EndNodeOutputType(Enum):
@@ -23,3 +29,40 @@ def value_of(cls, value: str) -> 'OutputType':
         if output_type.value == value:
             return output_type
         raise ValueError(f'invalid output type value {value}')
+
+
+class EndNodeDataOutputs(BaseModel):
+    """
+    END Node Data Outputs.
+    """
+    class OutputType(Enum):
+        """
+        Output Types.
+        """
+        NONE = 'none'
+        PLAIN_TEXT = 'plain-text'
+        STRUCTURED = 'structured'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'OutputType':
+        """
+        Get value of given output type.
+ + :param value: output type value + :return: output type + """ + for output_type in cls: + if output_type.value == value: + return output_type + raise ValueError(f'invalid output type value {value}') + + type: OutputType = OutputType.NONE + plain_text_selector: Optional[list[str]] = None + structured_variables: Optional[list[VariableSelector]] = None + + +class EndNodeData(BaseNodeData): + """ + END Node Data. + """ + outputs: Optional[EndNodeDataOutputs] = None diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py new file mode 100644 index 00000000000000..bd499543d903aa --- /dev/null +++ b/api/core/workflow/nodes/llm/entities.py @@ -0,0 +1,8 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class LLMNodeData(BaseNodeData): + """ + LLM Node Data. + """ + pass diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 1c7277e942e99f..e3ae9fc00f5f18 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,9 +1,28 @@ -from typing import Optional +from typing import Optional, cast +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.llm.entities import LLMNodeData class LLMNode(BaseNode): + _node_data_cls = LLMNodeData + node_type = NodeType.LLM + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + pass + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 64687db042f1dc..0bd5f203bf72a5 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,23 +1,9 @@ from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType class StartNodeData(BaseNodeData): """ - - title (string) 节点标题 - - desc (string) optional 节点描述 - - type (string) 节点类型,固定为 start - - variables (array[object]) 表单变量列表 - - type (string) 表单变量类型,text-input, paragraph, select, number, files(文件暂不支持自定义) - - label (string) 控件展示标签名 - - variable (string) 变量 key - - max_length (int) 最大长度,适用于 text-input 和 paragraph - - default (string) optional 默认值 - - required (bool) optional是否必填,默认 false - - hint (string) optional 提示信息 - - options (array[string]) 选项值(仅 select 可用) + Start Node Data """ - type: str = NodeType.START.value - variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 74d854143603e2..ce04031b046fe6 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,9 +1,11 @@ -from typing import Optional +from typing import Optional, cast -from core.workflow.entities.node_entities import NodeType +from core.app.app_config.entities import VariableEntity +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import 
BaseNode from core.workflow.nodes.start.entities import StartNodeData +from models.workflow import WorkflowNodeExecutionStatus class StartNode(BaseNode): @@ -11,12 +13,58 @@ class StartNode(BaseNode): node_type = NodeType.START def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node :param variable_pool: variable pool :param run_args: run args :return: """ - pass + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + variables = node_data.variables + # Get cleaned inputs + cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=cleaned_inputs, + outputs=cleaned_inputs + ) + + def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"Input form variable {variable} is required") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8ab0eb4802c913..5423546957ecca 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional, Union +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -141,6 +142,7 @@ def run_workflow(self, workflow: Workflow, workflow_run_state = WorkflowRunState( workflow_run=workflow_run, start_at=time.perf_counter(), + user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, ) @@ -399,7 +401,9 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool + variable_pool=workflow_run_state.variable_pool, + run_args=workflow_run_state.user_inputs + if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: @@ -492,7 +496,7 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode workflow_node_execution.inputs = 
json.dumps(result.inputs) workflow_node_execution.process_data = json.dumps(result.process_data) workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(result.metadata) + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() From 2ad9c76093aa1ccb7ceb4702a5bc2854c711897d Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 16:31:35 +0800 Subject: [PATCH 079/160] modify migrations --- ...5564d_conversation_columns_set_nullable.py | 48 +++++++++++++++++++ .../versions/b289e2408ee2_add_workflow.py | 2 - 2 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py new file mode 100644 index 00000000000000..f388b99b9068a0 --- /dev/null +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -0,0 +1,48 @@ +"""conversation columns set nullable + +Revision ID: 42e85ed5564d +Revises: f9107f83abab +Create Date: 2024-03-07 08:30:29.133614 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '42e85ed5564d' +down_revision = 'f9107f83abab' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 5ae1e65611ab3b..cf8530dc6782d2 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -78,8 +78,6 @@ def upgrade(): sa.Column('error', sa.Text(), nullable=True), sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), - sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), - sa.Column('currency', sa.String(length=255), nullable=True), sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), From b174f852377e9c534cbc67c2dcd271364e487fc9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 17:15:46 +0800 Subject: [PATCH 080/160] fix bug --- api/controllers/console/app/workflow.py | 2 +- api/fields/app_fields.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 4f8df6bcec5a34..5d70076821ac4b 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -65,7 +65,7 @@ def post(self, app_model: App): return { "result": "success", - "updated_at": TimestampField().format(workflow.updated_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 69ab1d3e3e8c3f..ccb95ad5731147 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -48,7 +48,7 @@ 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -68,7 +68,7 @@ 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -118,7 +118,7 @@ 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, 'created_at': TimestampField, From 1f986a3abbef7ae2cbcbdf0cd05acebeb48baeca Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 19:45:02 +0800 Subject: [PATCH 081/160] fix bugs --- api/controllers/console/app/workflow.py | 28 +++-- .../advanced_chat/generate_task_pipeline.py | 2 +- 
.../workflow_event_trigger_callback.py | 2 +- api/core/app/apps/chat/app_config_manager.py | 2 +- .../workflow_event_trigger_callback.py | 2 +- api/core/workflow/workflow_engine_manager.py | 101 +++++++++--------- .../versions/b289e2408ee2_add_workflow.py | 4 +- ...29b71023c_messages_columns_set_nullable.py | 41 +++++++ api/models/model.py | 4 +- api/models/workflow.py | 6 +- 10 files changed, 119 insertions(+), 73 deletions(-) create mode 100644 api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5d70076821ac4b..8a68cafad884b2 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse @@ -79,9 +80,9 @@ def post(self, app_model: App): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument('inputs', type=dict, location='json') + parser.add_argument('query', type=str, required=True, location='json', default='') + parser.add_argument('files', type=list, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') args = parser.parse_args() @@ -93,6 +94,8 @@ def post(self, app_model: App): args=args, invoke_from=InvokeFrom.DEBUGGER ) + + return compact_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -103,12 +106,6 @@ def post(self, app_model: App): logging.exception("internal server error.") raise InternalServerError() - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - class DraftWorkflowRunApi(Resource): @setup_required @@ -120,7 +117,7 @@ def post(self, app_model: App): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() workflow_service = WorkflowService() @@ -280,6 +277,17 @@ def post(self, app_model: App): return workflow +def compact_response(response: Union[dict, Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 84352f16c7505f..624a0f430af33f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ def 
_process_stream_response(self) -> Generator: response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 44fb5905b0fad9..5d99ce6297013c 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index ac69a928231420..553cf34ee9b142 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -46,7 +46,7 @@ def get_app_config(cls, app_model: App, else: config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG - if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS: + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 57775f2ccebd83..3d7a4035e7fdef 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 5423546957ecca..05a784c221de25 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -5,7 +5,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.nodes.base_node import BaseNode @@ -122,10 +122,10 @@ def run_workflow(self, workflow: Workflow, if 'nodes' not in graph or 'edges' not in graph: raise ValueError('nodes or edges not found in workflow graph') - if isinstance(graph.get('nodes'), list): + if not isinstance(graph.get('nodes'), list): raise ValueError('nodes in workflow graph must be a list') - if isinstance(graph.get('edges'), list): + if not 
isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') # init workflow run @@ -150,6 +150,7 @@ def run_workflow(self, workflow: Workflow, try: predecessor_node = None + has_entry_node = False while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( @@ -161,6 +162,8 @@ def run_workflow(self, workflow: Workflow, if not next_node: break + has_entry_node = True + # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -182,7 +185,7 @@ def run_workflow(self, workflow: Workflow, predecessor_node = next_node - if not predecessor_node and not next_node: + if not has_entry_node: self._workflow_run_failed( workflow_run_state=workflow_run_state, error='Start node not found in workflow graph.', @@ -219,38 +222,31 @@ def _init_workflow_run(self, workflow: Workflow, :param callbacks: workflow callbacks :return: """ - try: - db.session.begin() - - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .for_update() \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 - - # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, **system_inputs}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) - db.session.add(workflow_run) - db.session.commit() - except: - db.session.rollback() - raise + db.session.add(workflow_run) + db.session.commit() if callbacks: for callback in callbacks: @@ -330,7 +326,7 @@ def _get_next_node(self, graph: dict, if not predecessor_node: for node_config in nodes: - if node_config.get('type') == NodeType.START.value: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: return StartNode(config=node_config) else: edges = graph.get('edges') @@ -368,7 +364,7 @@ def _get_next_node(self, graph: dict, return None # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) + target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( config=target_node_config, @@ -424,17 +420,18 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, callbacks=callbacks ) - for variable_key, variable_value in 
node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) + if node_run_result.outputs: + for variable_key, variable_value in node_run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) - if node_run_result.metadata.get('total_tokens'): - workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) + if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) return workflow_node_execution @@ -464,7 +461,6 @@ def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRun node_id=node.node_id, node_type=node.node_type.value, title=node.node_data.title, - type=node.node_type.value, status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=workflow_run.created_by_role, created_by=workflow_run.created_by @@ -493,10 +489,11 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode """ workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(result.inputs) - workflow_node_execution.process_data = json.dumps(result.process_data) - workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) + workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None + workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None + workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \ + if result.metadata else None workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index cf8530dc6782d2..8fadf2dc6c98c7 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -45,8 +45,8 @@ def upgrade(): sa.Column('node_id', sa.String(length=255), nullable=False), sa.Column('node_type', sa.String(length=255), nullable=False), sa.Column('title', sa.String(length=255), nullable=False), - sa.Column('inputs', sa.Text(), nullable=False), - sa.Column('process_data', sa.Text(), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('process_data', sa.Text(), nullable=True), sa.Column('outputs', sa.Text(), nullable=True), sa.Column('status', sa.String(length=255), nullable=False), sa.Column('error', sa.Text(), nullable=True), diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py new file mode 100644 index 00000000000000..ee81fdab2872a2 --- /dev/null +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -0,0 +1,41 @@ +"""messages columns set 
nullable + +Revision ID: b5429b71023c +Revises: 42e85ed5564d +Create Date: 2024-03-07 09:52:00.846136 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b5429b71023c' +down_revision = '42e85ed5564d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index c579c3dee83399..6856c4e1b07d14 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -585,8 +585,8 @@ class Message(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) + model_provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) inputs = db.Column(db.JSON) diff --git a/api/models/workflow.py b/api/models/workflow.py index 032134a0d1c3ad..0883d0ef1321d4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -138,7 +138,7 @@ def user_input_form(self) -> list: if 'nodes' not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None) + start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) if not start_node: return [] @@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model): node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text, nullable=False) - process_data = db.Column(db.Text, nullable=False) + inputs = db.Column(db.Text) + process_data = db.Column(db.Text) outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) From 1914dfea7705c7d3d52059b52ab476941e745971 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 20:50:02 +0800 Subject: [PATCH 082/160] fix bugs --- .../advanced_chat/generate_task_pipeline.py | 24 ++++++++++++-- .../nodes/direct_answer/direct_answer_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 33 ++++++++++++++++++- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 624a0f430af33f..c1076fa947b1b5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -47,6 +47,7 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} usage: LLMUsage + workflow_run_id: Optional[str] = 
None class AdvancedChatAppGenerateTaskPipeline: @@ -110,6 +111,8 @@ def _process_blocking_response(self) -> dict: } self._task_state.answer = annotation.content + elif isinstance(event, QueueWorkflowStartedEvent): + self._task_state.workflow_run_id = event.workflow_run_id elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: @@ -171,6 +174,7 @@ def _process_stream_response(self) -> Generator: break elif isinstance(event, QueueWorkflowStartedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) + self._task_state.workflow_run_id = workflow_run.id response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -234,7 +238,7 @@ def _process_stream_response(self) -> Generator: if isinstance(event, QueueWorkflowFinishedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -389,7 +393,13 @@ def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: :param workflow_run_id: workflow run id :return: """ - return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + if workflow_run: + # Because the workflow_run will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_run) + return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: """ @@ -397,7 +407,14 @@ def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> Workf :param workflow_node_execution_id: workflow node execution id :return: """ - return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) + if workflow_node_execution: + # Because the workflow_node_execution will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_node_execution) + return workflow_node_execution def _save_message(self) -> None: """ @@ -408,6 +425,7 @@ def _save_message(self) -> None: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 80ecdf77571ecc..bc6e4bd8008aa6 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -48,7 +48,7 @@ def _run(self, variable_pool: Optional[VariablePool] = None, return NodeRunResult( 
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variable_values, - output={ + outputs={ "answer": answer } ) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 05a784c221de25..19dac76631ff44 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -33,6 +33,7 @@ WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, + WorkflowType, ) node_classes = { @@ -268,7 +269,7 @@ def _workflow_run_success(self, workflow_run_state: WorkflowRunState, # fetch last workflow_node_executions last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] if last_workflow_node_execution: - workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs) + workflow_run.outputs = last_workflow_node_execution.outputs workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at workflow_run.total_tokens = workflow_run_state.total_tokens @@ -390,6 +391,7 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, workflow_run_state=workflow_run_state, node=node, predecessor_node=predecessor_node, + callbacks=callbacks ) # add to workflow node executions @@ -412,6 +414,9 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, ) raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") + # set end node output if in chat + self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) + # node run success self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, @@ -529,6 +534,32 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE return workflow_node_execution + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + node_run_result: NodeRunResult): + """ + Set end node output if in chat + :param workflow_run_state: workflow run state + :param node: current node + :param node_run_result: node run result + :return: + """ + if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2] + if workflow_node_execution_before_end: + if workflow_node_execution_before_end.node_type == NodeType.LLM.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text') + elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer') + + return node_run_result + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], From 1a0b6adc2ced6860a477570d0d01b112fc9dd354 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 16:44:42 +0800 Subject: [PATCH 083/160] fix stream bugs --- api/core/app/apps/advanced_chat/app_generator.py | 2 +- .../app/apps/advanced_chat/generate_task_pipeline.py | 2 +- .../advanced_chat/workflow_event_trigger_callback.py | 2 +- api/core/app/apps/base_app_queue_manager.py | 9 +++++++-- api/core/app/apps/workflow/generate_task_pipeline.py | 2 +- .../app/apps/workflow/workflow_event_trigger_callback.py | 2 +- api/core/app/entities/queue_entities.py | 2 +- 7 files changed, 13 insertions(+), 8 
deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a19a5c8f6763f2..92286c9af0ed02 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -54,7 +54,7 @@ def generate(self, app_model: App, inputs = args['inputs'] extras = { - "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False } # get conversation diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c1076fa947b1b5..9c06f516a5b361 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -346,7 +346,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueTextChunkEvent): - delta_text = event.chunk_text + delta_text = event.text if delta_text is None: continue diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 5d99ce6297013c..8f72305bb1fbc0 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -76,7 +76,7 @@ def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: streamable_node_ids = [] end_node_ids = [] for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('type') == NodeType.END.value: end_node_ids.append(node_config.get('id')) for edge_config in graph.get('edges'): diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 0391599040c97d..289567fe5dbb39 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -15,6 +15,7 @@ QueueMessageEndEvent, QueuePingEvent, QueueStopEvent, + QueueWorkflowFinishedEvent, ) from extensions.ext_redis import redis_client @@ -36,7 +37,8 @@ def __init__(self, task_id: str, self._invoke_from = invoke_from user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, + f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -106,7 +108,10 @@ def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: self._q.put(message) - if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent): + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowFinishedEvent): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index df83ad634eecc8..bcd5a4ba3df966 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -248,7 +248,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(workflow_run_response) elif isinstance(event, QueueTextChunkEvent): 
- delta_text = event.chunk_text + delta_text = event.text if delta_text is None: continue diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 3d7a4035e7fdef..12b93518ed55f0 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -76,7 +76,7 @@ def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: streamable_node_ids = [] end_node_ids = [] for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('type') == NodeType.END.value: if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': end_node_ids.append(node_config.get('id')) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index e5c6a8eff943dd..38f9638eaa387c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -48,7 +48,7 @@ class QueueTextChunkEvent(AppQueueEvent): QueueTextChunkEvent entity """ event = QueueEvent.TEXT_CHUNK - chunk_text: str + text: str class QueueAgentMessageEvent(AppQueueEvent): From c152d55f68f1da84b56ed50e01072b16683eaea6 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 18:37:08 +0800 Subject: [PATCH 084/160] fix workflow app bugs --- api/controllers/console/app/workflow.py | 8 +-- .../advanced_chat/generate_task_pipeline.py | 55 ++++++++++--------- .../apps/message_based_app_queue_manager.py | 3 +- .../app/apps/workflow/app_queue_manager.py | 3 +- .../apps/workflow/generate_task_pipeline.py | 34 ++++++++++-- api/core/app/entities/queue_entities.py | 17 +++++- api/models/workflow.py | 2 +- 7 files changed, 79 insertions(+), 43 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8a68cafad884b2..30d383ec025ce3 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -129,18 +129,14 @@ def post(self, app_model: App): args=args, invoke_from=InvokeFrom.DEBUGGER ) + + return compact_response(response) except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") raise InternalServerError() - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - class WorkflowTaskStopApi(Resource): @setup_required diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 9c06f516a5b361..db22607146a2d7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -235,36 +235,39 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: workflow_run = self._get_workflow_run(event.workflow_run_id) - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - data = 
self._error_to_stream_response_data(self._handle_error(err_event)) - yield self._yield_response(data) - break - workflow_run_response = { - 'event': 'workflow_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'status': workflow_run.status, - 'outputs': workflow_run.outputs_dict, - 'error': workflow_run.error, - 'elapsed_time': workflow_run.elapsed_time, - 'total_tokens': workflow_run.total_tokens, - 'total_steps': workflow_run.total_steps, - 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) - } + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + else: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) } + } - yield self._yield_response(workflow_run_response) + yield self._yield_response(workflow_run_response) # response moderation if self._output_moderation_handler: diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index ed9475502d465f..13644c99ae8470 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -2,6 +2,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueMessage, ) @@ -20,7 +21,7 @@ def __init__(self, task_id: str, self._message_id = str(message_id) def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return QueueMessage( + return MessageQueueMessage( task_id=self._task_id, message_id=self._message_id, conversation_id=self._conversation_id, diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 0f9b0a1c78722e..5cf1e589132ef7 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -3,6 +3,7 @@ from core.app.entities.queue_entities import ( AppQueueEvent, QueueMessage, + WorkflowQueueMessage, ) @@ -16,7 +17,7 @@ def __init__(self, task_id: str, self._app_mode = app_mode def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return QueueMessage( + return WorkflowQueueMessage( task_id=self._task_id, app_mode=self._app_mode, event=event diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index bcd5a4ba3df966..a48640766a2d24 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -86,7 +86,7 @@ def 
_process_blocking_response(self) -> dict: workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) @@ -136,12 +136,11 @@ def _process_stream_response(self) -> Generator: break elif isinstance(event, QueueWorkflowStartedEvent): self._task_state.workflow_run_id = event.workflow_run_id - workflow_run = self._get_workflow_run(event.workflow_run_id) response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -198,7 +197,7 @@ def _process_stream_response(self) -> Generator: workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -228,6 +227,9 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(replace_response) + # save workflow app log + self._save_workflow_app_log() + workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, @@ -295,7 +297,13 @@ def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: :param workflow_run_id: workflow run id :return: """ - return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + if workflow_run: + # Because the workflow_run will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_run) + return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: """ @@ -303,7 +311,21 @@ def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> Workf :param workflow_node_execution_id: workflow node execution id :return: """ - return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) + if workflow_node_execution: + # Because the workflow_node_execution will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_node_execution) + return workflow_node_execution + + def _save_workflow_app_log(self) -> None: + """ + Save workflow app log. 
+        :return:
+        """
+        pass  # todo
 
     def _handle_chunk(self, text: str) -> dict:
         """
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 38f9638eaa387c..67ed13d7214645 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -176,7 +176,20 @@ class QueueMessage(BaseModel):
     QueueMessage entity
     """
     task_id: str
-    message_id: str
-    conversation_id: str
     app_mode: str
     event: AppQueueEvent
+
+
+class MessageQueueMessage(QueueMessage):
+    """
+    MessageQueueMessage entity
+    """
+    message_id: str
+    conversation_id: str
+
+
+class WorkflowQueueMessage(QueueMessage):
+    """
+    WorkflowQueueMessage entity
+    """
+    pass
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 0883d0ef1321d4..9768c364dd66ec 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -143,7 +143,7 @@ def user_input_form(self) -> list:
             return []
 
         # get user_input_form from start node
-        return start_node.get('variables', [])
+        return start_node.get('data', {}).get('variables', [])

From 736e386f15bba02e55b958682c17531eceda5ee6 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Fri, 8 Mar 2024 21:35:58 +0800
Subject: [PATCH 085/160] fix: bugs

---
 api/core/app/apps/agent_chat/app_config_manager.py | 2 +-
 api/core/app/apps/completion/app_config_manager.py | 2 +-
 api/services/completion_service.py                 | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py
index 57214f924a6023..232211c18b9843 100644
--- a/api/core/app/apps/agent_chat/app_config_manager.py
+++ b/api/core/app/apps/agent_chat/app_config_manager.py
@@ -52,7 +52,7 @@ def get_app_config(cls, app_model: App,
         else:
             config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
 
-        if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py
index a82e68a337ab4c..b98a4c16aaf996 100644
--- a/api/core/app/apps/completion/app_config_manager.py
+++ b/api/core/app/apps/completion/app_config_manager.py
@@ -37,7 +37,7 @@ def get_app_config(cls, app_model: App,
         else:
             config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
 
-        if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 4e3c4e19f649c6..eb31ccbb3bf1ed 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -30,16 +30,16 @@ def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
                 invoke_from=invoke_from,
                 stream=streaming
             )
-        elif app_model.mode == AppMode.CHAT.value:
-            return ChatAppGenerator().generate(
+        elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
+            return AgentChatAppGenerator().generate(
                 app_model=app_model,
                 user=user,
                 args=args,
                 invoke_from=invoke_from,
                 stream=streaming
             )
-        elif app_model.mode == AppMode.AGENT_CHAT.value:
-            return AgentChatAppGenerator().generate(
+        elif app_model.mode == AppMode.CHAT.value:
+            return ChatAppGenerator().generate(
                 app_model=app_model,
                 user=user,
                 args=args,
                 invoke_from=invoke_from,
                 stream=streaming
             )
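The two app_config_manager hunks in the patch above fix the same slip: the old guard compared override_config_dict, a dict or None, against an enum member, a comparison that is always unequal, so the override branch could never run and request-supplied configs were silently ignored. A minimal standalone sketch of the corrected control flow, with illustrative names only:

    from enum import Enum

    class ConfigFrom(Enum):
        ARGS = 'args'
        APP_LATEST_CONFIG = 'app-latest-config'

    def resolve_config(stored_config: dict, override_config_dict: dict | None) -> dict:
        # derive where the config comes from, as get_app_config() does
        config_from = ConfigFrom.ARGS if override_config_dict else ConfigFrom.APP_LATEST_CONFIG
        # pre-fix this line read `if override_config_dict != ConfigFrom.ARGS:`,
        # which is true for every dict and for None alike
        if config_from != ConfigFrom.ARGS:
            return dict(stored_config)
        return dict(override_config_dict)

    assert resolve_config({'model': 'stored'}, None) == {'model': 'stored'}
    assert resolve_config({'model': 'stored'}, {'model': 'override'}) == {'model': 'override'}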
From cb02b1e12e316e6dfd0c995cc71b98b0f995adec Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Fri, 8 Mar 2024 23:52:51 +0800
Subject: [PATCH 086/160] feat: code

---
 api/.env.example                              |   4 +
 api/config.py                                 |   7 +-
 api/core/workflow/nodes/code/code_executor.py |  70 +++++++
 api/core/workflow/nodes/code/code_node.py     | 177 +++++++++++++++++-
 api/core/workflow/nodes/code/entities.py      |  19 ++
 .../workflow/nodes/code/python_template.py    |  57 ++++++
 6 files changed, 332 insertions(+), 2 deletions(-)
 create mode 100644 api/core/workflow/nodes/code/code_executor.py
 create mode 100644 api/core/workflow/nodes/code/entities.py
 create mode 100644 api/core/workflow/nodes/code/python_template.py

diff --git a/api/.env.example b/api/.env.example
index 32d89d4287c599..4a3b1d65afdfc0 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -132,3 +132,7 @@ SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
 
 BATCH_UPLOAD_LIMIT=10
+
+# CODE EXECUTION CONFIGURATION
+CODE_EXECUTION_ENDPOINT=
+CODE_EXECUTION_API_KEY=
diff --git a/api/config.py b/api/config.py
index a978a099b9ed4d..a6bc731b820a4e 100644
--- a/api/config.py
+++ b/api/config.py
@@ -59,7 +59,9 @@
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
     'KEYWORD_STORE': 'jieba',
-    'BATCH_UPLOAD_LIMIT': 20
+    'BATCH_UPLOAD_LIMIT': 20,
+    'CODE_EXECUTION_ENDPOINT': '',
+    'CODE_EXECUTION_API_KEY': ''
 }
 
@@ -293,6 +295,9 @@ def __init__(self):
 
         self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
 
+        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
+        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+
         self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
 
diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/workflow/nodes/code/code_executor.py
new file mode 100644
index 00000000000000..3ecd7cfd8948e7
--- /dev/null
+++ b/api/core/workflow/nodes/code/code_executor.py
@@ -0,0 +1,70 @@
+from os import environ
+
+from httpx import post
+from yarl import URL
+from pydantic import BaseModel
+
+from core.workflow.nodes.code.python_template import PythonTemplateTransformer
+
+# Code Executor
+CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '')
+CODE_EXECUTION_API_KEY = environ.get('CODE_EXECUTION_API_KEY', '')
+
+class CodeExecutionException(Exception):
+    pass
+
+class CodeExecutionResponse(BaseModel):
+    class Data(BaseModel):
+        stdout: str
+        stderr: str
+
+    code: int
+    message: str
+    data: Data
+
+class CodeExecutor:
+    @classmethod
+    def execute_code(cls, language: str, code: str, inputs: dict) -> dict:
+        """
+        Execute code
+        :param language: code language
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+        runner = PythonTemplateTransformer.transform_caller(code, inputs)
+
+        url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
+        headers = {
+            'X-Api-Key': CODE_EXECUTION_API_KEY
+        }
+        data = {
+            'language': language,
+            'code': runner,
+        }
+
+        try:
+            response = post(str(url), json=data, headers=headers)
+            if response.status_code == 503:
+                raise CodeExecutionException('Code execution service is unavailable')
+            elif response.status_code != 200:
+                raise Exception('Failed to execute code')
+        except CodeExecutionException as e:
+            raise e
+        except Exception:
+            raise CodeExecutionException('Failed to execute code')
+
+        try:
+            response = response.json()
+        except Exception:
+            raise CodeExecutionException('Failed to parse response')
+
+        response = CodeExecutionResponse(**response)
+
+        if response.code != 0:
+            raise CodeExecutionException(response.message)
+
+        if response.data.stderr:
+            raise CodeExecutionException(response.data.stderr)
+
+        return PythonTemplateTransformer.transform_response(response.data.stdout)
\ No newline at end of file
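For orientation before the node wiring below, a hedged usage sketch of the new executor; only CodeExecutor.execute_code and CodeExecutionException come from the file above, while the 'python3' language tag and the sample snippet are assumptions:

    from core.workflow.nodes.code.code_executor import CodeExecutionException, CodeExecutor

    # the runner template wraps this snippet, calls main(**inputs) in the
    # sandbox, and prints the JSON-encoded return value between <<RESULT>> markers
    code = (
        "def main(a: int, b: int) -> dict:\n"
        "    return {'sum': a + b}\n"
    )

    try:
        outputs = CodeExecutor.execute_code(language='python3', code=code, inputs={'a': 1, 'b': 2})
        print(outputs)  # {'sum': 3}
    except CodeExecutionException as e:
        print(f'sandbox rejected the run: {e}')

Note that execute_code never returns partial output: a non-200 response, a non-zero sandbox code, or any stderr text surfaces as CodeExecutionException, which the node below converts into a failed NodeRunResult.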
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 7e69f91d118d7d..dc69fdc84ac6fa 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,9 +1,23 @@
-from typing import Optional
+from typing import Optional, cast, Union
 
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.code_executor import CodeExecutor, CodeExecutionException
+from models.workflow import WorkflowNodeExecutionStatus
 
+MAX_NUMBER = 2 ** 63 - 1
+MIN_NUMBER = -2 ** 63
+MAX_PRECISION = 20
+MAX_DEPTH = 5
+MAX_STRING_LENGTH = 1000
+MAX_STRING_ARRAY_LENGTH = 30
 
 class CodeNode(BaseNode):
+    _node_data_cls = CodeNodeData
+    node_type = NodeType.CODE
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
@@ -62,3 +76,164 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
             ]
         }
     }
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run code
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data: CodeNodeData = cast(self._node_data_cls, node_data)
+
+        # SINGLE DEBUG NOT IMPLEMENTED YET
+        if variable_pool is None and run_args:
+            raise ValueError("Single step debug is not supported.")
+
+        # Get code language
+        code_language = node_data.code_language
+        code = node_data.code
+
+        # Get variables
+        variables = {}
+        for variable_selector in node_data.variables:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            variables[variable] = value
+
+        # Run code
+        try:
+            result = CodeExecutor.execute_code(
+                language=code_language,
+                code=code,
+                inputs=variables
+            )
+        except CodeExecutionException as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )
+
+        # Transform result
+        result = self._transform_result(result, node_data.outputs)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variables,
+            outputs=result
+        )
+
+    def _check_string(self, value: str, variable: str) -> str:
+        """
+        Check string
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, str):
+            raise ValueError(f"{variable} in input form must be a string")
+
+        if len(value) > MAX_STRING_LENGTH:
+            raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
+
+        return value.replace('\x00', '')
+
+    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
+        """
+        Check number
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, (int, float)):
+            raise ValueError(f"{variable} in input form must be a number")
+
+        if value > MAX_NUMBER or value < MIN_NUMBER:
+            raise ValueError(f'{variable} in input form is out of range.')
+
+        if isinstance(value, float):
+            value = round(value, MAX_PRECISION)
+
+        return value
+
+    def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
+                          prefix: str = '',
+                          depth: int = 1) -> dict:
+        """
+        Transform result
+        :param result: result
+        :param output_schema: output schema
+        :return:
+        """
+        if depth > MAX_DEPTH:
+            raise ValueError("Depth limit reached, object too deep.")
+
+        transformed_result = {}
+        for output_name, output_config in output_schema.items():
+            if output_config.type == 'object':
+                # check if output is object
+                if not isinstance(result.get(output_name), dict):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
+                    )
+
+                transformed_result[output_name] = self._transform_result(
+                    result=result[output_name],
+                    output_schema=output_config.children,
+                    prefix=f'{prefix}.{output_name}' if prefix else output_name,
+                    depth=depth + 1
+                )
+            elif output_config.type == 'number':
+                # check if number available
+                transformed_result[output_name] = self._check_number(
+                    value=result[output_name],
+                    variable=f'{prefix}.{output_name}' if prefix else output_name
+                )
+            elif output_config.type == 'string':
+                # check if string available
+                transformed_result[output_name] = self._check_string(
+                    value=result[output_name],
+                    variable=f'{prefix}.{output_name}' if prefix else output_name,
+                )
+            elif output_config.type == 'array[number]':
+                # check if array of number available
+                if not isinstance(result[output_name], list):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
+                    )
+
+                transformed_result[output_name] = [
+                    self._check_number(
+                        value=value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                    for value in result[output_name]
+                ]
+            elif output_config.type == 'array[string]':
+                # check if array of string available
+                if not isinstance(result[output_name], list):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
+                    )
+
+                if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
+                    raise ValueError(
+                        f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} elements'
+                    )
+
+                transformed_result[output_name] = [
+                    self._check_string(
+                        value=value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                    for value in result[output_name]
+                ]
+            else:
+                raise ValueError(f'Output type {output_config.type} is not supported.')
+
+        return transformed_result
\ No newline at end of file
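To make _transform_result concrete, here is a hypothetical outputs declaration (plain dicts, as they would look before being parsed into the CodeNodeData.Output models defined in the next file) together with a sandbox result that satisfies it; the field names are illustrative:

    # one top-level number plus a nested object holding a string array
    outputs_schema = {
        'score': {'type': 'number', 'children': None},
        'detail': {
            'type': 'object',
            'children': {
                'labels': {'type': 'array[string]', 'children': None},
            },
        },
    }

    # _transform_result walks the parsed schema depth-first and copies only
    # declared keys: floats are rounded to MAX_PRECISION digits, strings and
    # arrays are length-capped, and nesting past MAX_DEPTH or an unknown
    # type tag raises ValueError
    result = {'score': 0.92, 'detail': {'labels': ['ok', 'spam']}}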
+ """ + class Output(BaseModel): + type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] + children: Union[None, dict[str, 'Output']] + + variables: list[VariableSelector] + answer: str + code_language: str + code: str + outputs: dict[str, Output] diff --git a/api/core/workflow/nodes/code/python_template.py b/api/core/workflow/nodes/code/python_template.py new file mode 100644 index 00000000000000..03dfee36f3e488 --- /dev/null +++ b/api/core/workflow/nodes/code/python_template.py @@ -0,0 +1,55 @@ +import json +import re + +PYTHON_RUNNER = """# declare main function here +{{code}} + +# execute main function, and return the result +# inputs is a dict, and it +output = main(**{{inputs}}) + +# convert output to json and print +result = ''' +<> +{output} +<> +''' + +print(result) +""" + + +class PythonTemplateTransformer: + @classmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + + # transform inputs to json string + inputs_str = json.dumps(inputs, indent=4) + + # replace code and inputs + runner = PYTHON_RUNNER.replace('{{code}}', code) + runner = runner.replace('{{inputs}}', inputs_str) + + return runner + + @classmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + + # extract result + result = re.search(r'<>(.*)<>', response, re.DOTALL) + if not result: + raise ValueError('Failed to parse result') + + result = result.group(1) + return json.loads(result) From 5596b3b00b0dbbc3658b70e16bc9b64bd27fa682 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 8 Mar 2024 23:53:18 +0800 Subject: [PATCH 087/160] fix: linter --- api/core/workflow/nodes/code/code_executor.py | 2 +- api/core/workflow/nodes/code/code_node.py | 8 ++++---- api/core/workflow/nodes/code/entities.py | 6 ++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/workflow/nodes/code/code_executor.py index 3ecd7cfd8948e7..058ee83d4646da 100644 --- a/api/core/workflow/nodes/code/code_executor.py +++ b/api/core/workflow/nodes/code/code_executor.py @@ -1,8 +1,8 @@ from os import environ from httpx import post -from yarl import URL from pydantic import BaseModel +from yarl import URL from core.workflow.nodes.code.python_template import PythonTemplateTransformer diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index dc69fdc84ac6fa..32f67768500821 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,10 +1,10 @@ -from typing import Optional, cast, Union +from typing import Optional, Union, cast + from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool - from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.code.code_executor import CodeExecutionException, CodeExecutor from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.code.code_executor import CodeExecutor, CodeExecutionException from models.workflow import WorkflowNodeExecutionStatus MAX_NUMBER = 2 ** 63 - 1 @@ -151,7 +151,7 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f :param variable: variable :return: """ - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): raise ValueError(f"{variable} 
in input form must be a number") if value > MAX_NUMBER or value < MIN_NUMBER: diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 731b00f8c8a62b..2212d77e2d4d88 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,8 +1,10 @@ +from typing import Literal, Union + +from pydantic import BaseModel + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector -from pydantic import BaseModel -from typing import Literal, Union class CodeNodeData(BaseNodeData): """ From fc573564b4f321233b2ddc1b3bf642c2834a7762 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 23:59:09 +0800 Subject: [PATCH 088/160] refactor workflow runner --- api/controllers/console/app/workflow.py | 7 +- .../app/apps/advanced_chat/app_generator.py | 31 +- api/core/app/apps/advanced_chat/app_runner.py | 33 +- .../advanced_chat/generate_task_pipeline.py | 220 +++++++++--- .../workflow_event_trigger_callback.py | 83 ++++- api/core/app/apps/agent_chat/app_generator.py | 4 +- api/core/app/apps/base_app_queue_manager.py | 27 +- api/core/app/apps/chat/app_generator.py | 4 +- api/core/app/apps/completion/app_generator.py | 4 +- .../app/apps/message_based_app_generator.py | 4 +- .../apps/message_based_app_queue_manager.py | 35 +- api/core/app/apps/workflow/app_generator.py | 14 +- .../app/apps/workflow/app_queue_manager.py | 30 +- api/core/app/apps/workflow/app_runner.py | 33 +- .../apps/workflow/generate_task_pipeline.py | 207 +++++++++--- .../workflow_event_trigger_callback.py | 83 ++++- .../workflow_based_generate_task_pipeline.py | 202 +++++++++++ api/core/app/entities/queue_entities.py | 66 +++- .../callbacks/base_workflow_callback.py | 44 ++- .../workflow/entities/workflow_entities.py | 26 +- .../nodes/direct_answer/direct_answer_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 319 ++++-------------- api/services/workflow_service.py | 19 +- 23 files changed, 996 insertions(+), 501 deletions(-) create mode 100644 api/core/app/apps/workflow_based_generate_task_pipeline.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 30d383ec025ce3..5f03a7cd377744 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -147,9 +147,12 @@ def post(self, app_model: App, task_id: str): """ Stop workflow task """ - # TODO workflow_service = WorkflowService() - workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user) + workflow_service.stop_workflow_task( + task_id=task_id, + user=current_user, + invoke_from=InvokeFrom.DEBUGGER + ) return { "result": "success" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 92286c9af0ed02..ed45e2ba8aab97 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -11,7 +11,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom 
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom @@ -123,11 +123,13 @@ def generate(self, app_model: App, worker_thread.start() # return response or stream generator - return self._handle_response( + return self._handle_advanced_chat_response( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) @@ -159,7 +161,7 @@ def _generate_worker(self, flask_app: Flask, conversation=conversation, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -177,33 +179,40 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.remove() - def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: + def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False) -> Union[dict, Generator]: """ Handle response. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager :param conversation: conversation :param message: message + :param user: account or end user :param stream: is stream :return: """ # init generate task pipeline generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, + user=user, + stream=stream ) try: - return generate_task_pipeline.process(stream=stream) + return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 077f0c2de0ed2a..3279e00355e915 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging import time -from typing import cast +from typing import Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback @@ -8,16 +8,14 @@ from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser, Message -from models.workflow 
import WorkflowRunTriggeredFrom +from models.model import App, Conversation, Message +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -46,7 +44,7 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, if not app_record: raise ValueError("App not found") - workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: raise ValueError("Workflow not initialized") @@ -74,19 +72,10 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, ): return - # fetch user - if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: - user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() - else: - user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, - user=user, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, @@ -99,6 +88,20 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, )] ) + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index db22607146a2d7..18bc9c80080c6f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,9 +4,10 @@ from collections.abc import Generator from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Extra from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -16,25 +17,35 @@ QueueErrorEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import 
NodeRunMetadataKey, NodeType, SystemVariable from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageFile -from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus +from models.account import Account +from models.model import Conversation, EndUser, Message, MessageFile +from models.workflow import ( + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) @@ -47,41 +58,63 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} usage: LLMUsage - workflow_run_id: Optional[str] = None + + workflow_run: Optional[WorkflowRun] = None + start_at: Optional[float] = None + total_tokens: int = 0 + total_steps: int = 0 + + current_node_execution: Optional[WorkflowNodeExecution] = None + current_node_execution_start_at: Optional[float] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True -class AdvancedChatAppGenerateTaskPipeline: +class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, conversation: Conversation, - message: Message) -> None: + message: Message, + user: Union[Account, EndUser], + stream: bool) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager :param conversation: conversation :param message: message + :param user: user + :param stream: stream """ self._application_generate_entity = application_generate_entity + self._workflow = workflow self._queue_manager = queue_manager self._conversation = conversation self._message = message + self._user = user self._task_state = TaskState( usage=LLMUsage.empty_usage() ) self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() + self._stream = stream - def process(self, stream: bool) -> Union[dict, Generator]: + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ - if stream: + if self._stream: return self._process_stream_response() else: return self._process_blocking_response() @@ -112,22 +145,17 @@ def _process_blocking_response(self) -> dict: self._task_state.answer = annotation.content elif isinstance(event, QueueWorkflowStartedEvent): - self._task_state.workflow_run_id = event.workflow_run_id - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) - if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueWorkflowFinishedEvent): - workflow_run = self._get_workflow_run(event.workflow_run_id) - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs - self._task_state.answer = outputs.get('text', '') - else: - raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run + + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) # response moderation if self._output_moderation_handler: @@ -173,8 +201,9 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._get_workflow_run(event.workflow_run_id) - self._task_state.workflow_run_id = workflow_run.id + self._on_workflow_start() + workflow_run = self._task_state.workflow_run + response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -188,7 +217,9 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + self._on_node_start(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 'event': 'node_started', 'task_id': self._application_generate_entity.task_id, @@ -204,8 +235,10 @@ def _process_stream_response(self) -> Generator: } yield self._yield_response(response) - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + workflow_node_execution = self._task_state.current_node_execution + if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict @@ -234,16 +267,11 @@ def _process_stream_response(self) -> Generator: } yield self._yield_response(response) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, 
QueueStopEvent): - workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) data = self._error_to_stream_response_data(self._handle_error(err_event)) yield self._yield_response(data) @@ -252,7 +280,7 @@ def _process_stream_response(self) -> Generator: workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -390,6 +418,102 @@ def _process_stream_response(self) -> Generator: else: continue + def _on_workflow_start(self) -> None: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + user_inputs=self._application_generate_entity.inputs, + system_inputs={ + SystemVariable.QUERY: self._message.query, + SystemVariable.FILES: self._application_generate_entity.files, + SystemVariable.CONVERSATION: self._conversation.id, + } + ) + + self._task_state.workflow_run = workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run=self._task_state.workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + self._task_state.current_node_execution = workflow_node_execution + self._task_state.current_node_execution_start_at = time.perf_counter() + self._task_state.total_steps += 1 + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if workflow_node_execution.node_type == NodeType.LLM.value: + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + error=event.error + ) + + self._task_state.current_node_execution = workflow_node_execution + + def 
_on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + workflow_run = self._workflow_run_success( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=self._task_state.current_node_execution.outputs + if self._task_state.current_node_execution else None + ) + + self._task_state.workflow_run = workflow_run + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Get workflow run. @@ -397,11 +521,6 @@ def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: :return: """ workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - if workflow_run: - # Because the workflow_run will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_run) return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: @@ -412,11 +531,6 @@ def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> Workf """ workflow_node_execution = (db.session.query(WorkflowNodeExecution) .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - if workflow_node_execution: - # Because the workflow_node_execution will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_node_execution) return workflow_node_execution def _save_message(self) -> None: @@ -428,7 +542,7 @@ def _save_message(self) -> None: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.workflow_run_id = self._task_state.workflow_run_id + self._message.workflow_run_id = self._task_state.workflow_run.id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 8f72305bb1fbc0..d9c8a2c96ddb6c 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -1,14 +1,19 @@ +from typing import Optional + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( - QueueNodeFinishedEvent, + QueueNodeFailedEvent, 
QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun +from models.workflow import Workflow class WorkflowEventTriggerCallback(BaseWorkflowCallback): @@ -17,39 +22,91 @@ def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ self._queue_manager.publish( - QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowStartedEvent(), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_succeeded(self) -> None: + """ + Workflow run succeeded + """ + self._queue_manager.publish( + QueueWorkflowSucceededEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_failed(self, error: str) -> None: """ - Workflow run finished + Workflow run failed """ self._queue_manager.publish( - QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowFailedEvent( + error=error + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ self._queue_manager.publish( - QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeStartedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + node_run_index=node_run_index, + predecessor_node_id=predecessor_node_id + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: + """ + Workflow node execute succeeded + """ + self._queue_manager.publish( + QueueNodeSucceededEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: """ - Workflow node execute finished + Workflow node execute failed """ self._queue_manager.publish( - QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeFailedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + error=error + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 
6d27620a0986fb..700a340c969980 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -11,7 +11,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -177,7 +177,7 @@ def _generate_worker(self, flask_app: Flask, conversation=conversation, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 289567fe5dbb39..43a44819f9495e 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -11,11 +11,8 @@ from core.app.entities.queue_entities import ( AppQueueEvent, QueueErrorEvent, - QueueMessage, - QueueMessageEndEvent, QueuePingEvent, QueueStopEvent, - QueueWorkflowFinishedEvent, ) from extensions.ext_redis import redis_client @@ -103,22 +100,16 @@ def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: :return: """ self._check_for_sqlalchemy_models(event.dict()) - - message = self.construct_queue_message(event) - - self._q.put(message) - - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowFinishedEvent): - self.stop_listen() - - if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise ConversationTaskStoppedException() + self._publish(event, pub_from) @abstractmethod - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ raise NotImplementedError @classmethod @@ -182,5 +173,5 @@ def _check_for_sqlalchemy_models(self, data: Any): "that cause thread safety issues is not allowed.") -class ConversationTaskStoppedException(Exception): +class GenerateTaskStoppedException(Exception): pass diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 7ddf8dfe32b71a..317d045c043083 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -9,7 +9,7 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator @@ -177,7 +177,7 @@ def 
_generate_worker(self, flask_app: Flask, conversation=conversation, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 7150bee3cefcfe..b948938aac24ad 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -9,7 +9,7 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator @@ -166,7 +166,7 @@ def _generate_worker(self, flask_app: Flask, queue_manager=queue_manager, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 3dee68b5e1db59..0e76c96ff7fc6b 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -7,7 +7,7 @@ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, @@ -60,7 +60,7 @@ def _handle_response(self, application_generate_entity: Union[ return generate_task_pipeline.process(stream=stream) except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 13644c99ae8470..6d0a71f495e328 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,9 +1,14 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, MessageQueueMessage, + QueueErrorEvent, QueueMessage, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, ) @@ -28,3 +33,31 @@ def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: app_mode=self._app_mode, event=event ) + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param 
pub_from: + :return: + """ + message = MessageQueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() + diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 891ca4c2bedcee..d3303047caa250 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -9,7 +9,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -95,7 +95,9 @@ def generate(self, app_model: App, # return response or stream generator return self._handle_response( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, + user=user, stream=stream ) @@ -117,7 +119,7 @@ def _generate_worker(self, flask_app: Flask, application_generate_entity=application_generate_entity, queue_manager=queue_manager ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -136,19 +138,25 @@ def _generate_worker(self, flask_app: Flask, db.session.remove() def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, + user: Union[Account, EndUser], stream: bool = False) -> Union[dict, Generator]: """ Handle response. 
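+ Returns a dict with the final result when stream is False, otherwise a generator yielding stream responses.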
:param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager + :param user: account or end user :param stream: is stream :return: """ # init generate task pipeline generate_task_pipeline = WorkflowAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, + user=user, stream=stream ) @@ -156,7 +164,7 @@ def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntit return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 5cf1e589132ef7..f448138b53c0c2 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,8 +1,12 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, - QueueMessage, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, WorkflowQueueMessage, ) @@ -16,9 +20,27 @@ def __init__(self, task_id: str, self._app_mode = app_mode - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return WorkflowQueueMessage( + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage( task_id=self._task_id, app_mode=self._app_mode, event=event ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 132282ffe3f6e3..59a385cb38c5a3 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,13 +1,12 @@ import logging import time -from typing import cast +from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, - InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent @@ -16,9 +15,8 @@ from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser -from models.workflow import WorkflowRunTriggeredFrom +from models.model import App +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -43,7 +41,7 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, 
if not app_record: raise ValueError("App not found") - workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: raise ValueError("Workflow not initialized") @@ -59,19 +57,10 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, ): return - # fetch user - if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: - user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() - else: - user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, - user=user, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files @@ -82,6 +71,20 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, )] ) + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: WorkflowAppGenerateEntity, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a48640766a2d24..721124c4c54166 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,28 +4,35 @@ from collections.abc import Generator from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Extra from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( + InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import ( QueueErrorEvent, QueueMessageReplaceEvent, - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable from extensions.ext_database import db -from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus +from models.account import Account +from models.model import EndUser +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -36,24 +43,44 @@ class TaskState(BaseModel): """ answer: str = "" metadata: dict = {} - workflow_run_id: 
Optional[str] = None + workflow_run: Optional[WorkflowRun] = None + start_at: Optional[float] = None + total_tokens: int = 0 + total_steps: int = 0 -class WorkflowAppGenerateTaskPipeline: + current_node_execution: Optional[WorkflowNodeExecution] = None + current_node_execution_start_at: Optional[float] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + +class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): """ - WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + WorkflowAppGenerateTaskPipeline generates stream output and manages state for the workflow application. """ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, + user: Union[Account, EndUser], stream: bool) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager + :param user: account or end user + :param stream: is stream """ self._application_generate_entity = application_generate_entity + self._workflow = workflow self._queue_manager = queue_manager + self._user = user self._task_state = TaskState() self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() @@ -79,17 +106,15 @@ def _process_blocking_response(self) -> dict: if isinstance(event, QueueErrorEvent): raise self._handle_error(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueStopEvent): - workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + elif isinstance(event, QueueWorkflowStartedEvent): + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run # response moderation if self._output_moderation_handler: @@ -100,10 +125,12 @@ def _process_blocking_response(self) -> dict: public_event=False ) + # save workflow app log + self._save_workflow_app_log() + response = { - 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -135,8 +162,9 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._task_state.workflow_run_id = event.workflow_run_id - workflow_run = self._get_workflow_run(event.workflow_run_id) + self._on_workflow_start() + workflow_run = self._task_state.workflow_run + response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -150,7 +178,9 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = 
self._get_workflow_node_execution(event.workflow_node_execution_id) + self._on_node_start(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 'event': 'node_started', 'task_id': self._application_generate_entity.task_id, @@ -166,8 +196,10 @@ def _process_stream_response(self) -> Generator: } yield self._yield_response(response) - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 'event': 'node_finished', 'task_id': self._application_generate_entity.task_id, @@ -190,20 +222,9 @@ def _process_stream_response(self) -> Generator: } yield self._yield_response(response) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueStopEvent): - workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - data = self._error_to_stream_response_data(self._handle_error(err_event)) - yield self._yield_response(data) - break + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run # response moderation if self._output_moderation_handler: @@ -219,7 +240,7 @@ def _process_stream_response(self) -> Generator: replace_response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'data': { 'text': self._task_state.answer } @@ -233,7 +254,7 @@ def _process_stream_response(self) -> Generator: workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -244,7 +265,7 @@ def _process_stream_response(self) -> Generator: 'total_tokens': workflow_run.total_tokens, 'total_steps': workflow_run.total_steps, 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) + 'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None } } @@ -279,7 +300,7 @@ def _process_stream_response(self) -> Generator: response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'data': { 'text': event.text } @@ -291,6 +312,95 @@ def _process_stream_response(self) -> Generator: else: continue + def _on_workflow_start(self) -> None: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + 
user_inputs=self._application_generate_entity.inputs, + system_inputs={ + SystemVariable.FILES: self._application_generate_entity.files + } + ) + + self._task_state.workflow_run = workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run=self._task_state.workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + self._task_state.current_node_execution = workflow_node_execution + self._task_state.current_node_execution_start_at = time.perf_counter() + self._task_state.total_steps += 1 + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + error=event.error + ) + + self._task_state.current_node_execution = workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + workflow_run = self._workflow_run_success( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=self._task_state.current_node_execution.outputs + if self._task_state.current_node_execution else None + ) + + self._task_state.workflow_run = workflow_run + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Get workflow run. 
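+ :param workflow_run_id: workflow run id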
@@ -298,11 +408,6 @@ def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: :return: """ workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - if workflow_run: - # Because the workflow_run will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_run) return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: @@ -313,11 +418,6 @@ def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> Workf """ workflow_node_execution = (db.session.query(WorkflowNodeExecution) .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - if workflow_node_execution: - # Because the workflow_node_execution will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_node_execution) return workflow_node_execution def _save_workflow_app_log(self) -> None: @@ -335,7 +435,7 @@ def _handle_chunk(self, text: str) -> dict: """ response = { 'event': 'text_chunk', - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'task_id': self._application_generate_entity.task_id, 'data': { 'text': text @@ -398,7 +498,6 @@ def _error_to_stream_response_data(self, e: Exception) -> dict: return { 'event': 'error', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, **data } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 12b93518ed55f0..318466711a3252 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -1,14 +1,19 @@ +from typing import Optional + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun +from models.workflow import Workflow class WorkflowEventTriggerCallback(BaseWorkflowCallback): @@ -17,39 +22,91 @@ def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ self._queue_manager.publish( - QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowStartedEvent(), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_succeeded(self) -> None: + """ + Workflow run succeeded + """ + self._queue_manager.publish( + QueueWorkflowSucceededEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def 
on_workflow_run_failed(self, error: str) -> None: """ - Workflow run finished + Workflow run failed """ self._queue_manager.publish( - QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowFailedEvent( + error=error + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ self._queue_manager.publish( - QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeStartedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + node_run_index=node_run_index, + predecessor_node_id=predecessor_node_id + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: + """ + Workflow node execute succeeded + """ + self._queue_manager.publish( + QueueNodeSucceededEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: """ - Workflow node execute finished + Workflow node execute failed """ self._queue_manager.publish( - QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeFailedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + error=error + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py new file mode 100644 index 00000000000000..3e9a7b9e1fc98e --- /dev/null +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -0,0 +1,202 @@ +import json +import time +from datetime import datetime +from typing import Optional, Union + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeType +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) + + +class WorkflowBasedGenerateTaskPipeline: + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + 
.filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + + return workflow_run + + def _workflow_run_success(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Optional[dict] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :return: + """ + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value + workflow_run.outputs = outputs + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_run + + def _workflow_run_failed(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + status: WorkflowRunStatus, + error: str) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param status: status + :param error: error message + :return: + """ + workflow_run.status = status.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_run + + def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run: workflow run + :param node_id: node id + :param node_type: node type + :param node_title: node title + :param node_run_index: run index + :param predecessor_node_id: predecessor node id if exists + :return: + """ + # init workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by + ) + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + 
outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + """ + Workflow node execution success + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param inputs: inputs + :param process_data: process data + :param outputs: outputs + :param execution_metadata: execution metadata + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_node_execution + + def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + error: str) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param error: error message + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_node_execution diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 67ed13d7214645..0ea7744b582943 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,9 +1,11 @@ from enum import Enum -from typing import Any +from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType class QueueEvent(Enum): @@ -16,9 +18,11 @@ class QueueEvent(Enum): MESSAGE_REPLACE = "message_replace" MESSAGE_END = "message_end" WORKFLOW_STARTED = "workflow_started" - WORKFLOW_FINISHED = "workflow_finished" + WORKFLOW_SUCCEEDED = "workflow_succeeded" + WORKFLOW_FAILED = "workflow_failed" NODE_STARTED = "node_started" - NODE_FINISHED = "node_finished" + NODE_SUCCEEDED = "node_succeeded" + NODE_FAILED = "node_failed" RETRIEVER_RESOURCES = "retriever_resources" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" @@ -96,15 +100,21 @@ class QueueWorkflowStartedEvent(AppQueueEvent): QueueWorkflowStartedEvent entity """ event = QueueEvent.WORKFLOW_STARTED - workflow_run_id: str -class QueueWorkflowFinishedEvent(AppQueueEvent): +class QueueWorkflowSucceededEvent(AppQueueEvent): """ - QueueWorkflowFinishedEvent entity + QueueWorkflowSucceededEvent entity """ - event = QueueEvent.WORKFLOW_FINISHED - workflow_run_id: str + event = QueueEvent.WORKFLOW_SUCCEEDED + + +class QueueWorkflowFailedEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + event = QueueEvent.WORKFLOW_FAILED + error: str class QueueNodeStartedEvent(AppQueueEvent): @@ -112,17 +122,45 @@ class QueueNodeStartedEvent(AppQueueEvent): QueueNodeStartedEvent entity """ event = QueueEvent.NODE_STARTED - 
workflow_node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + node_run_index: int = 1 + predecessor_node_id: Optional[str] = None -class QueueNodeFinishedEvent(AppQueueEvent): + +class QueueNodeSucceededEvent(AppQueueEvent): """ - QueueNodeFinishedEvent entity + QueueNodeSucceededEvent entity """ - event = QueueEvent.NODE_FINISHED - workflow_node_execution_id: str + event = QueueEvent.NODE_SUCCEEDED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + execution_metadata: Optional[dict] = None + + error: Optional[str] = None + + +class QueueNodeFailedEvent(AppQueueEvent): + """ + QueueNodeFailedEvent entity + """ + event = QueueEvent.NODE_FAILED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + error: str + - class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 3866bf2c1518eb..cf2915ed8646d4 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,34 +1,63 @@ from abc import ABC, abstractmethod +from typing import Optional -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType class BaseWorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ raise NotImplementedError @abstractmethod - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_succeeded(self) -> None: """ - Workflow run finished + Workflow run succeeded """ raise NotImplementedError @abstractmethod - def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ raise NotImplementedError @abstractmethod - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: """ - Workflow node execute finished + Workflow node execute succeeded + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: + """ + Workflow node execute failed """ raise NotImplementedError @@ -38,4 +67,3 @@ def on_node_text_chunk(self, node_id: str, text: str) -> None: Publish text chunk """ raise NotImplementedError - diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 8c15cb95cdced3..6c2adfe0fbf9ea 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ 
b/api/core/workflow/entities/workflow_entities.py @@ -1,22 +1,32 @@ +from typing import Optional + +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.nodes.base_node import BaseNode +from models.workflow import Workflow + + +class WorkflowNodeAndResult: + node: BaseNode + result: Optional[NodeRunResult] = None + + def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): + self.node = node + self.result = result class WorkflowRunState: - workflow_run: WorkflowRun + workflow: Workflow start_at: float user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 - workflow_node_executions: list[WorkflowNodeExecution] = [] + workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow_run: WorkflowRun, - start_at: float, - user_inputs: dict, - variable_pool: VariablePool) -> None: - self.workflow_run = workflow_run + def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool): + self.workflow = workflow self.start_at = start_at self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index bc6e4bd8008aa6..971cbe536e25cc 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -43,7 +43,7 @@ def _run(self, variable_pool: Optional[VariablePool] = None, # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) + time.sleep(0.01) # TODO: revisit this hard-coded 0.01s streaming delay return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 19dac76631ff44..628df4ac5fb6c6 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,13 +1,11 @@ -import json import time -from datetime import datetime -from typing import Optional, Union +from typing import Optional -from core.model_runtime.utils.encoders import jsonable_encoder +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowRunState +from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -21,18 +19,9 @@ from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode -from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser from models.workflow import ( - CreatedByRole, Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, - 
WorkflowRunStatus, - WorkflowRunTriggeredFrom, WorkflowType, ) @@ -53,20 +42,6 @@ class WorkflowEngineManager: - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() - - # return workflow - return workflow - def get_default_configs(self) -> list[dict]: """ Get default block configs @@ -100,16 +75,12 @@ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None return default_config def run_workflow(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], user_inputs: dict, system_inputs: Optional[dict] = None, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks @@ -130,18 +101,13 @@ def run_workflow(self, workflow: Workflow, raise ValueError('edges in workflow graph must be a list') # init workflow run - workflow_run = self._init_workflow_run( - workflow=workflow, - triggered_from=triggered_from, - user=user, - user_inputs=user_inputs, - system_inputs=system_inputs, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started() # init workflow run state workflow_run_state = WorkflowRunState( - workflow_run=workflow_run, + workflow=workflow, start_at=time.perf_counter(), user_inputs=user_inputs, variable_pool=VariablePool( @@ -166,7 +132,7 @@ def run_workflow(self, workflow: Workflow, has_entry_node = True # max steps 30 reached - if len(workflow_run_state.workflow_node_executions) > 30: + if len(workflow_run_state.workflow_nodes_and_results) > 30: raise ValueError('Max steps 30 reached.') # or max execution time 10min reached @@ -188,14 +154,14 @@ def run_workflow(self, workflow: Workflow, if not has_entry_node: self._workflow_run_failed( - workflow_run_state=workflow_run_state, error='Start node not found in workflow graph.', callbacks=callbacks ) return + except GenerateTaskStoppedException as e: + return except Exception as e: self._workflow_run_failed( - workflow_run_state=workflow_run_state, error=str(e), callbacks=callbacks ) @@ -203,112 +169,33 @@ def run_workflow(self, workflow: Workflow, # workflow run success self._workflow_run_success( - workflow_run_state=workflow_run_state, callbacks=callbacks ) - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 - - # init workflow run - workflow_run = WorkflowRun( - 
tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) - - db.session.add(workflow_run) - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started(workflow_run) - - return workflow_run - - def _workflow_run_success(self, workflow_run_state: WorkflowRunState, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success - :param workflow_run_state: workflow run state :param callbacks: workflow callbacks :return: """ - workflow_run = workflow_run_state.workflow_run - workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - - # fetch last workflow_node_executions - last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] - if last_workflow_node_execution: - workflow_run.outputs = last_workflow_node_execution.outputs - - workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at - workflow_run.total_tokens = workflow_run_state.total_tokens - workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() if callbacks: for callback in callbacks: - callback.on_workflow_run_finished(workflow_run) + callback.on_workflow_run_succeeded() - return workflow_run - - def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, - error: str, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + def _workflow_run_failed(self, error: str, + callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run failed - :param workflow_run_state: workflow run state :param error: error message :param callbacks: workflow callbacks :return: """ - workflow_run = workflow_run_state.workflow_run - workflow_run.status = WorkflowRunStatus.FAILED.value - workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at - workflow_run.total_tokens = workflow_run_state.total_tokens - workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() - if callbacks: for callback in callbacks: - callback.on_workflow_run_finished(workflow_run) - - return workflow_run + callback.on_workflow_run_failed( + error=error + ) def _get_next_node(self, graph: dict, predecessor_node: Optional[BaseNode] = None, @@ -384,18 +271,24 @@ def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node: BaseNode, predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - # init workflow node execution - start_at = time.perf_counter() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run_state=workflow_run_state, + callbacks: list[BaseWorkflowCallback] = None) -> None: + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_started( + node_id=node.node_id, + node_type=node.node_type, + 
node_data=node.node_data, + node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None + ) + + workflow_nodes_and_result = WorkflowNodeAndResult( node=node, - predecessor_node=predecessor_node, - callbacks=callbacks + result=None ) - # add to workflow node executions - workflow_run_state.workflow_node_executions.append(workflow_node_execution) + # add to workflow_nodes_and_results + workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( @@ -406,24 +299,34 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=start_at, - error=node_run_result.error, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_failed( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + error=node_run_result.error + ) + raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") # set end node output if in chat self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) + workflow_nodes_and_result.result = node_run_result + # node run success - self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=start_at, - result=node_run_result, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_succeeded( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, + outputs=node_run_result.outputs, + execution_metadata=node_run_result.metadata + ) if node_run_result.outputs: for variable_key, variable_value in node_run_result.outputs.items(): @@ -438,105 +341,9 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - return workflow_node_execution - - def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run_state: workflow run state - :param node: current node - :param predecessor_node: predecessor node if exists - :param callbacks: workflow callbacks - :return: - """ - workflow_run = workflow_run_state.workflow_run - - # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - index=len(workflow_run_state.workflow_node_executions) + 1, - node_id=node.node_id, - node_type=node.node_type.value, - title=node.node_data.title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - 
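Because elapsed time, token totals, and DB rows are no longer computed inside the engine, a callback that wants wall-clock timings has to keep its own clock, which is exactly what the task pipelines in patch 092 below do. A sketch of that pattern (illustrative, not the actual pipeline code):

import time
from typing import Optional

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType


class NodeTimingMixin:
    # intended to mix into a BaseWorkflowCallback subclass; keeps per-node start times
    def __init__(self) -> None:
        self._node_start_at: dict[str, float] = {}

    def on_workflow_node_execute_started(self, node_id: str,
                                         node_type: NodeType,
                                         node_data: BaseNodeData,
                                         node_run_index: int = 1,
                                         predecessor_node_id: Optional[str] = None) -> None:
        self._node_start_at[node_id] = time.perf_counter()

    def on_workflow_node_execute_succeeded(self, node_id: str, **kwargs) -> None:
        # absorbs node_type/node_data/inputs/outputs keyword arguments
        elapsed = time.perf_counter() - self._node_start_at.pop(node_id)
        print(f'node {node_id} finished in {elapsed:.3f}s')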
created_by=workflow_run.created_by - ) - - db.session.add(workflow_node_execution) - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started(workflow_node_execution) - - return workflow_node_execution - - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - result: NodeRunResult, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param result: node run result - :param callbacks: workflow callbacks - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None - workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None - workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \ - if result.metadata else None - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_finished(workflow_node_execution) - - return workflow_node_execution - - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_finished(workflow_node_execution) - - return workflow_node_execution - def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, node: BaseNode, - node_run_result: NodeRunResult): + node_run_result: NodeRunResult) -> None: """ Set end node output if in chat :param workflow_run_state: workflow run state @@ -544,21 +351,19 @@ def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, :param node_run_result: node run result :return: """ - if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: - workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2] - if workflow_node_execution_before_end: - if workflow_node_execution_before_end.node_type == NodeType.LLM.value: + if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] + if workflow_nodes_and_result_before_end: + if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM.value: if not node_run_result.outputs: node_run_result.outputs = {} - node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text') - elif 
workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value: + node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') + elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER.value: if not node_run_result.outputs: node_run_result.outputs = {} - node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer') - - return node_run_result + node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer') def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 833c22cdffaa01..f8bd80a0b1c125 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom @@ -44,10 +45,14 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: if not app_model.workflow_id: return None - workflow_engine_manager = WorkflowEngineManager() - # fetch published workflow by workflow_id - return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id) + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id + ).first() + + return workflow def sync_draft_workflow(self, app_model: App, graph: dict, @@ -201,6 +206,14 @@ def run_draft_workflow(self, app_model: App, return response + def stop_workflow_task(self, task_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom) -> None: + """ + Stop workflow task + """ + AppQueueManager.set_stop_flag(task_id, invoke_from, user.id) + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ Basic mode of chatbot app(expert mode) to workflow From 9b0f83f807d908bc1c7c8ec61fd4c319e8f0f995 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 00:02:44 +0800 Subject: [PATCH 089/160] fix: add max number array length --- api/core/workflow/nodes/code/code_node.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 32f67768500821..e7e8a1c25100ff 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -13,6 +13,7 @@ MAX_DEPTH = 5 MAX_STRING_LENGTH = 1000 MAX_STRING_ARRAY_LENGTH = 30 +MAX_NUMBER_ARRAY_LENGTH = 1000 class CodeNode(BaseNode): _node_data_cls = CodeNodeData @@ -210,6 +211,11 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' 
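stop_workflow_task above is intentionally thin: it only sets a stop flag, which the queue manager surfaces inside the running engine as the GenerateTaskStoppedException handled in run_workflow. A hypothetical service-level call; `task_id`, `current_user`, and the InvokeFrom.DEBUGGER member are assumed from the request context and enum, not shown in this diff:

from core.app.entities.app_invoke_entities import InvokeFrom
from services.workflow_service import WorkflowService

WorkflowService().stop_workflow_task(
    task_id=task_id,
    user=current_user,
    invoke_from=InvokeFrom.DEBUGGER
)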
) + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: + raise ValueError( + f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters' + ) + transformed_result[output_name] = [ self._check_number( value=value, From e90637f67a89042f0327fce5699b07b90768daa1 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 9 Mar 2024 00:58:12 +0800 Subject: [PATCH 090/160] fix generate bug --- api/core/app/apps/advanced_chat/app_generator.py | 4 ++-- api/core/app/apps/workflow/app_generator.py | 2 -- api/core/workflow/workflow_engine_manager.py | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ed45e2ba8aab97..a0f197ec374465 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -216,5 +216,5 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh else: logger.exception(e) raise e - finally: - db.session.remove() + # finally: + # db.session.remove() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index d3303047caa250..b1a70a83ba42e3 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -168,5 +168,3 @@ def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntit else: logger.exception(e) raise e - finally: - db.session.remove() diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 628df4ac5fb6c6..c5af015e87187f 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -354,12 +354,12 @@ def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] if workflow_nodes_and_result_before_end: - if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM.value: + if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: if not node_run_result.outputs: node_run_result.outputs = {} node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') - elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER.value: + elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER: if not node_run_result.outputs: node_run_result.outputs = {} From 4c5822fb6e2cad159793e85076a94313b1245ec0 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 15:51:02 +0800 Subject: [PATCH 091/160] fix: transform --- api/core/workflow/nodes/code/code_node.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index e7e8a1c25100ff..77bcccab21713a 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -196,8 +196,6 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. 
value=result[output_name], variable=f'{prefix}.{output_name}' if prefix else output_name ) - - transformed_result[output_name] = result[output_name] elif output_config.type == 'string': # check if string available transformed_result[output_name] = self._check_string( From 2f57d090a1291087512f5f8ecc11c074fe2f71c5 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 9 Mar 2024 19:05:48 +0800 Subject: [PATCH 092/160] refactor pipeline and remove node run run_args --- .../advanced_chat/generate_task_pipeline.py | 47 ++++++++---- .../apps/workflow/generate_task_pipeline.py | 48 +++++++++---- api/core/workflow/entities/variable_pool.py | 5 +- .../workflow/entities/workflow_entities.py | 4 +- api/core/workflow/nodes/base_node.py | 34 ++++++--- api/core/workflow/nodes/code/code_node.py | 45 ++++++------ .../nodes/direct_answer/direct_answer_node.py | 21 +++--- api/core/workflow/nodes/end/end_node.py | 71 ++++++++++--------- api/core/workflow/nodes/llm/llm_node.py | 16 ++++- api/core/workflow/nodes/start/start_node.py | 18 +++-- api/core/workflow/workflow_engine_manager.py | 6 +- 11 files changed, 201 insertions(+), 114 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 18bc9c80080c6f..048b429304c980 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -55,6 +55,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} usage: LLMUsage @@ -64,8 +77,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -218,7 +231,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -237,7 +250,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -447,15 +460,21 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> None: predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() 
+ ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -472,12 +491,14 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -504,8 +525,8 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 721124c4c54166..26e4769fa61ce1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -41,6 +41,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} @@ -49,8 +62,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -179,7 +192,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = 
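Both pipelines now share the same bookkeeping: one NodeExecutionInfo per in-flight node keyed by node_id, plus a pointer to the most recently touched one, so elapsed time can be computed per node even when executions overlap. The pattern in isolation, with a simplified stand-in for the pydantic model (all values illustrative):

import time
from typing import Optional

from pydantic import BaseModel


class NodeExecutionInfo(BaseModel):
    # simplified stand-in for TaskState.NodeExecutionInfo
    node_execution_id: str
    start_at: float


running: dict[str, NodeExecutionInfo] = {}
latest: Optional[NodeExecutionInfo] = None

# node started: register under its node_id and remember when it began
info = NodeExecutionInfo(node_execution_id='exec-1', start_at=time.perf_counter())
running['node-1'] = info
latest = info

# node finished: pop by node_id and compute elapsed time from the recorded start
elapsed = time.perf_counter() - running.pop('node-1').start_at
print(f'elapsed: {elapsed:.4f}s')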
self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -198,7 +211,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_finished', @@ -339,15 +352,22 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> None: predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -359,12 +379,14 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -391,8 +413,8 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e84044dede1982..3868041a8f28ea 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -19,14 +19,17 
@@ class ValueType(Enum): class VariablePool: variables_mapping = {} + user_inputs: dict - def __init__(self, system_variables: dict[SystemVariable, Any]) -> None: + def __init__(self, system_variables: dict[SystemVariable, Any], + user_inputs: dict) -> None: # system variables # for example: # { # 'query': 'abc', # 'files': [] # } + self.user_inputs = user_inputs for system_variable, value in system_variables.items(): self.append_variable('sys', [system_variable.value], value) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 6c2adfe0fbf9ea..768ad6a1303a97 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -18,15 +18,13 @@ def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): class WorkflowRunState: workflow: Workflow start_at: float - user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): self.workflow = workflow self.start_at = start_at - self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6720017d9f0814..3f2e806433e289 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -28,31 +28,23 @@ def __init__(self, config: dict, self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ raise NotImplementedError - def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node entry :param variable_pool: variable pool - :param run_args: run args :return: """ - if variable_pool is None and run_args is None: - raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") - try: result = self._run( - variable_pool=variable_pool, - run_args=run_args + variable_pool=variable_pool ) except Exception as e: # process unhandled exception @@ -77,6 +69,26 @@ def publish_text_chunk(self, text: str) -> None: text=text ) + @classmethod + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: + """ + Extract variable selector to variable mapping + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping(node_data) + + @classmethod + @abstractmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + raise NotImplementedError + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 77bcccab21713a..a65edafbad23c9 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,6 @@ 
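VariablePool now carries the raw user inputs next to the system variables, which is what lets StartNode read them without the old run_args path (see the StartNode hunk below). A minimal sketch, assuming the SystemVariable enum from node_entities and the selector shape used by the 'sys' namespace:

from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool(
    system_variables={SystemVariable.QUERY: 'hello'},
    user_inputs={'name': 'Alice'}
)

# system variables are registered under the 'sys' namespace by __init__
print(pool.get_variable_value(variable_selector=['sys', 'query']))  # 'hello'
# raw user inputs are now reachable by nodes such as StartNode
print(pool.user_inputs)  # {'name': 'Alice'}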
from typing import Optional, Union, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -15,6 +16,7 @@ MAX_STRING_ARRAY_LENGTH = 30 MAX_NUMBER_ARRAY_LENGTH = 1000 + class CodeNode(BaseNode): _node_data_cls = CodeNodeData node_type = NodeType.CODE @@ -78,21 +80,15 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } } - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run code :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(self._node_data_cls, node_data) - # SINGLE DEBUG NOT IMPLEMENTED YET - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - # Get code language code_language = node_data.code_language code = node_data.code @@ -134,7 +130,6 @@ def _check_string(self, value: str, variable: str) -> str: Check string :param value: value :param variable: variable - :param max_length: max length :return: """ if not isinstance(value, str): @@ -142,9 +137,9 @@ def _check_string(self, value: str, variable: str) -> str: if len(value) > MAX_STRING_LENGTH: raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters') - + return value.replace('\x00', '') - + def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ Check number @@ -157,13 +152,13 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f if value > MAX_NUMBER or value < MIN_NUMBER: raise ValueError(f'{variable} in input form is out of range.') - + if isinstance(value, float): value = round(value, MAX_PRECISION) return value - def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], + def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], prefix: str = '', depth: int = 1) -> dict: """ @@ -174,7 +169,7 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. """ if depth > MAX_DEPTH: raise ValueError("Depth limit reached, object too deep.") - + transformed_result = {} for output_name, output_config in output_schema.items(): if output_config.type == 'object': @@ -183,7 +178,7 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. raise ValueError( f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.' ) - + transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, @@ -208,7 +203,7 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' ) - + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters' @@ -227,12 +222,12 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' 
) - + if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters' ) - + transformed_result[output_name] = [ self._check_string( value=value, @@ -242,5 +237,15 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. ] else: raise ValueError(f'Output type {output_config.type} is not supported.') - - return transformed_result \ No newline at end of file + + return transformed_result + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + return {} diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 971cbe536e25cc..9193bab9ee5463 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -1,7 +1,8 @@ import time -from typing import Optional, cast +from typing import cast from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode): _node_data_cls = DirectAnswerNodeData node_type = NodeType.DIRECT_ANSWER - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - variable_values = {} for variable_selector in node_data.variables: value = variable_pool.get_variable_value( @@ -43,7 +39,7 @@ def _run(self, variable_pool: Optional[VariablePool] = None, # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) # todo sleep 0.01 + time.sleep(0.01) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -52,3 +48,12 @@ def _run(self, variable_pool: Optional[VariablePool] = None, "answer": answer } ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 62429e3ac284fe..65b0b86aa0314f 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,6 @@ -from typing import Optional, cast +from typing import cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -11,50 +12,54 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData node_type = NodeType.END 
- def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) outputs_config = node_data.outputs - if variable_pool is not None: - outputs = None - if outputs_config: - if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: - plain_text_selector = outputs_config.plain_text_selector - if plain_text_selector: - outputs = { - 'text': variable_pool.get_variable_value( - variable_selector=plain_text_selector, - target_value_type=ValueType.STRING - ) - } - else: - outputs = { - 'text': '' - } - elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: - structured_variables = outputs_config.structured_variables - if structured_variables: - outputs = {} - for variable_selector in structured_variables: - variable_value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - outputs[variable_selector.variable] = variable_value - else: - outputs = {} - else: - raise ValueError("Not support single step debug.") + outputs = None + if outputs_config: + if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: + plain_text_selector = outputs_config.plain_text_selector + if plain_text_selector: + outputs = { + 'text': variable_pool.get_variable_value( + variable_selector=plain_text_selector, + target_value_type=ValueType.STRING + ) + } + else: + outputs = { + 'text': '' + } + elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: + structured_variables = outputs_config.structured_variables + if structured_variables: + outputs = {} + for variable_selector in structured_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + else: + outputs = {} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index e3ae9fc00f5f18..90a7755b85de81 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,5 +1,6 @@ from typing import Optional, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -10,12 +11,10 @@ class LLMNode(BaseNode): _node_data_cls = LLMNodeData node_type = NodeType.LLM - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -23,6 +22,17 @@ def _run(self, variable_pool: Optional[VariablePool] = None, pass + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract 
variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + return {} + + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index ce04031b046fe6..2321e04bd4256e 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,7 @@ -from typing import Optional, cast +from typing import cast from core.app.app_config.entities import VariableEntity +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -12,12 +13,10 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData node_type = NodeType.START - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -25,7 +24,7 @@ def _run(self, variable_pool: Optional[VariablePool] = None, variables = node_data.variables # Get cleaned inputs - cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -68,3 +67,12 @@ def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict filtered_inputs[variable] = value.replace('\x00', '') if value else None return filtered_inputs + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index c5af015e87187f..0b96717de7f517 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -109,9 +109,9 @@ def run_workflow(self, workflow: Workflow, workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), - user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, + user_inputs=user_inputs ) ) @@ -292,9 +292,7 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool, - run_args=workflow_run_state.user_inputs - if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node + variable_pool=workflow_run_state.variable_pool ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: From a0fd731170a1fc6a890d7a9618b6c25e164b72c4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 19:45:57 +0800 Subject: [PATCH 093/160] feat: mapping variables --- api/core/workflow/nodes/code/code_node.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index a65edafbad23c9..170f2b9cd8eb7d 100644 --- 
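The CodeNode hunk that follows fills in the mapping from value selectors to node-local variable names, which single-step debugging can use to resolve a node's inputs. One caveat: value_selector is typed list[str], and a Python list is not hashable, so a dict keyed by it only works if the selector is converted to (or declared as) a tuple. A sketch of the intended shape, with hypothetical selectors:

# selector (node_id, output_key) -> the variable name used inside the code node
mapping = {
    ('start', 'user_name'): 'name',
    ('llm', 'text'): 'prompt_response',
}

for selector, variable in mapping.items():
    print(f'{variable} <- {".".join(selector)}')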
a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -87,7 +87,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: :return: """ node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data: CodeNodeData = cast(self._node_data_cls, node_data) # Get code language code_language = node_data.code_language @@ -241,11 +241,13 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[list[str], str]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - # TODO extract variable selector to variable mapping for single step debugging - return {} + + return { + variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 193bcce236176abc939693f80f3584d2fb1f36eb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 19:59:47 +0800 Subject: [PATCH 094/160] feat: http request --- api/core/workflow/nodes/code/code_node.py | 1 - .../workflow/nodes/http_request/entities.py | 31 +++++++++++++++++++ .../nodes/http_request/http_request_node.py | 20 ++++++++++-- 3 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/http_request/entities.py diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 170f2b9cd8eb7d..3d3c475d067ec5 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,5 @@ from typing import Optional, Union, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py new file mode 100644 index 00000000000000..8610e88e551721 --- /dev/null +++ b/api/core/workflow/nodes/http_request/entities.py @@ -0,0 +1,31 @@ +from typing import Literal, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class HttpRequestNodeData(BaseNodeData): + """ + Http Request Node Data.
+ """ + class Authorization(BaseModel): + class Config(BaseModel): + type: Literal[None, 'basic', 'bearer', 'custom'] + api_key: Union[None, str] + header: Union[None, str] + + type: Literal['no-auth', 'api-key'] + + class Body(BaseModel): + type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw'] + data: Union[None, str] + + variables: list[VariableSelector] + method: Literal['get', 'post', 'put', 'patch', 'delete'] + url: str + authorization: Authorization + headers: str + params: str + \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 5be25a9834d7af..d0fa29646f1b16 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,5 +1,21 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode - +from core.workflow.nodes.http_request.entities import HttpRequestNodeData class HttpRequestNode(BaseNode): - pass + _node_data_cls = HttpRequestNodeData + node_type = NodeType.HTTP_REQUEST + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + pass \ No newline at end of file From 614bc2e075eee1ab938e363a9168d776002e4dc4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:19:48 +0800 Subject: [PATCH 095/160] feat: http reqeust --- api/core/helper/ssrf_proxy.py | 4 + .../workflow/nodes/http_request/entities.py | 5 +- .../nodes/http_request/http_executor.py | 240 ++++++++++++++++++ .../nodes/http_request/http_request_node.py | 39 ++- 4 files changed, 285 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/http_request/http_executor.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0bfe763fac66bd..c44d4717e6da33 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -38,6 +38,10 @@ def patch(url, *args, **kwargs): return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) def delete(url, *args, **kwargs): + if 'follow_redirects' in kwargs: + if kwargs['follow_redirects']: + kwargs['allow_redirects'] = kwargs['follow_redirects'] + kwargs.pop('follow_redirects') return _delete(url=url, *args, proxies=requests_proxies, **kwargs) def head(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8610e88e551721..1e906cbaa4b777 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -17,9 +17,10 @@ class Config(BaseModel): header: Union[None, str] type: Literal['no-auth', 'api-key'] + config: Config class Body(BaseModel): - type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw'] + type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] data: Union[None, str] variables: list[VariableSelector] @@ -28,4 +29,4 @@ class Body(BaseModel): authorization: Authorization headers: str params: str - \ No newline at end of file + body: Body \ No newline at end of file diff --git 
a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py new file mode 100644 index 00000000000000..4b13e92e0c2b37 --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -0,0 +1,240 @@ +from copy import deepcopy +from typing import Any, Union +from urllib.parse import urlencode + +import httpx +import re +import requests +import core.helper.ssrf_proxy as ssrf_proxy +from core.workflow.nodes.http_request.entities import HttpRequestNodeData + +HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60) + +class HttpExecutorResponse: + status_code: int + headers: dict[str, str] + body: str + + def __init__(self, status_code: int, headers: dict[str, str], body: str): + """ + init + """ + self.status_code = status_code + self.headers = headers + self.body = body + +class HttpExecutor: + server_url: str + method: str + authorization: HttpRequestNodeData.Authorization + params: dict[str, Any] + headers: dict[str, Any] + body: Union[None, str] + files: Union[None, dict[str, Any]] + + def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init + """ + self.server_url = node_data.url + self.method = node_data.method + self.authorization = node_data.authorization + self.params = {} + self.headers = {} + self.body = None + + # init template + self._init_template(node_data, variables) + + def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init template + """ + # extract all template in url + url_template = re.findall(r'{{(.*?)}}', node_data.url) or [] + url_template = list(set(url_template)) + original_url = node_data.url + for url in url_template: + if not url: + continue + + original_url = original_url.replace(f'{{{{{url}}}}}', str(variables.get(url, ''))) + + self.server_url = original_url + + # extract all template in params + param_template = re.findall(r'{{(.*?)}}', node_data.params) or [] + param_template = list(set(param_template)) + original_params = node_data.params + for param in param_template: + if not param: + continue + + original_params = original_params.replace(f'{{{{{param}}}}}', str(variables.get(param, ''))) + + # fill in params + kv_paris = original_params.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid params {kv}') + + k, v = kv + self.params[k] = v + + # extract all template in headers + header_template = re.findall(r'{{(.*?)}}', node_data.headers) or [] + header_template = list(set(header_template)) + original_headers = node_data.headers + for header in header_template: + if not header: + continue + + original_headers = original_headers.replace(f'{{{{{header}}}}}', str(variables.get(header, ''))) + + # fill in headers + kv_paris = original_headers.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid headers {kv}') + + k, v = kv + self.headers[k] = v + + # extract all template in body + body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] + body_template = list(set(body_template)) + original_body = node_data.body.data or '' + for body in body_template: + if not body: + continue + + original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) + + if node_data.body.type == 'json': + self.headers['Content-Type'] = 'application/json' + elif node_data.body.type == 'x-www-form-urlencoded': + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + # elif node_data.body.type == 'form-data': + # 
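The {{variable}} substitution in _init_template above is plain string replacement driven by a single regex; its behaviour in isolation:

import re

template = 'https://api.example.com/users/{{user_id}}/orders/{{order_id}}'
variables = {'user_id': '42', 'order_id': '7'}

# mirrors _init_template: collect the unique placeholders, then replace each one
for key in set(re.findall(r'{{(.*?)}}', template)):
    template = template.replace('{{%s}}' % key, str(variables.get(key, '')))

assert template == 'https://api.example.com/users/42/orders/7'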
self.headers['Content-Type'] = 'multipart/form-data' + + if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + body = {} + kv_paris = original_body.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid body {kv}') + body[kv[0]] = kv[1] + + if node_data.body.type == 'form-data': + self.files = { + k: ('', v) for k, v in body.items() + } + else: + self.body = urlencode(body) + else: + self.body = original_body + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.authorization) + headers = deepcopy(self.headers) or [] + if self.authorization.type == 'api-key': + if self.authorization.config.api_key is None: + raise ValueError('api_key is required') + + if not self.authorization.config.header: + authorization.config.header = 'Authorization' + + if self.authorization.config.type == 'bearer': + headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' + elif self.authorization.config.type == 'basic': + headers[authorization.config.header] = f'Basic {authorization.config.api_key}' + elif self.authorization.config.type == 'custom': + headers[authorization.config.header] = authorization.config.api_key + + return headers + + def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: + """ + validate the response + """ + if isinstance(response, httpx.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + elif isinstance(response, requests.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + else: + raise ValueError(f'Invalid response type {type(response)}') + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + # do http request + kwargs = { + 'url': self.server_url, + 'headers': headers, + 'params': self.params, + 'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT, + 'follow_redirects': True + } + + if self.method == 'get': + response = ssrf_proxy.get(**kwargs) + elif self.method == 'post': + response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs) + elif self.method == 'put': + response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs) + elif self.method == 'delete': + response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs) + elif self.method == 'patch': + response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs) + elif self.method == 'head': + response = ssrf_proxy.head(**kwargs) + elif self.method == 'options': + response = ssrf_proxy.options(**kwargs) + else: + raise ValueError(f'Invalid http method {self.method}') + + return response + + def invoke(self) -> HttpExecutorResponse: + """ + invoke http request + """ + # assemble headers + headers = self._assembling_headers() + + # do http request + response = self._do_http_request(headers) + + # validate response + return self._validate_and_parse_response(response) + + def to_raw_request(self) -> str: + """ + convert to raw request + """ + server_url = self.server_url + if self.params: + server_url += f'?{urlencode(self.params)}' + + raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + for k, v in self.headers.items(): + raw_request += f'{k}: {v}\n' + + raw_request += 
'\n' + raw_request += self.body or '' + + return raw_request \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index d0fa29646f1b16..f55f48c4af75d4 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,15 +1,52 @@ +from os import error +from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.http_executor import HttpExecutor +from models.workflow import WorkflowNodeExecutionStatus + class HttpRequestNode(BaseNode): _node_data_cls = HttpRequestNodeData node_type = NodeType.HTTP_REQUEST def _run(self, variable_pool: VariablePool) -> NodeRunResult: - pass + node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value(variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # init http executor + try: + http_executor = HttpExecutor(node_data=node_data, variables=variables) + # invoke http executor + + response = http_executor.invoke() + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + process_data=http_executor.to_raw_request() + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs={ + 'status_code': response.status_code, + 'body': response, + 'headers': response.headers + }, + process_data=http_executor.to_raw_request() + ) + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: From 3d5f9b5a1eb9d2921339ac3fe96b0dd6426170af Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:26:19 +0800 Subject: [PATCH 096/160] fix: missing _extract_variable_selector_to_variable_mapping --- api/core/workflow/nodes/http_request/http_executor.py | 3 ++- api/core/workflow/nodes/http_request/http_request_node.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 4b13e92e0c2b37..82d879a89ccb4e 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -1,10 +1,11 @@ +import re from copy import deepcopy from typing import Any, Union from urllib.parse import urlencode import httpx -import re import requests + import core.helper.ssrf_proxy as ssrf_proxy from core.workflow.nodes.http_request.entities import HttpRequestNodeData diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index f55f48c4af75d4..e3e864b6b02f04 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,5 +1,5 @@ -from os import error from typing import cast + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import 
NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool @@ -49,10 +49,12 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[list[str], str]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - pass \ No newline at end of file + return { + variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 2895c3bc8c997efcaef70f6008917e38c4366d22 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:49:53 +0800 Subject: [PATCH 097/160] feat: template transform --- .../code_executor}/code_executor.py | 15 +++-- .../code_executor/javascript_transformer.py | 1 + .../helper/code_executor/jina2_transformer.py | 1 + .../code_executor/python_transformer.py} | 4 +- .../code_executor/template_transformer.py | 24 ++++++++ api/core/workflow/nodes/code/code_node.py | 2 +- api/core/workflow/nodes/code/entities.py | 2 +- .../nodes/http_request/http_request_node.py | 1 - .../nodes/template_transform/entities.py | 14 +++++ .../template_transform_node.py | 59 ++++++++++++++++++- 10 files changed, 114 insertions(+), 9 deletions(-) rename api/core/{workflow/nodes/code => helper/code_executor}/code_executor.py (75%) create mode 100644 api/core/helper/code_executor/javascript_transformer.py create mode 100644 api/core/helper/code_executor/jina2_transformer.py rename api/core/{workflow/nodes/code/python_template.py => helper/code_executor/python_transformer.py} (90%) create mode 100644 api/core/helper/code_executor/template_transformer.py create mode 100644 api/core/workflow/nodes/template_transform/entities.py diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/helper/code_executor/code_executor.py similarity index 75% rename from api/core/workflow/nodes/code/code_executor.py rename to api/core/helper/code_executor/code_executor.py index 058ee83d4646da..f1bc4fbdafde4b 100644 --- a/api/core/workflow/nodes/code/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,10 +1,11 @@ from os import environ +from typing import Literal from httpx import post from pydantic import BaseModel from yarl import URL -from core.workflow.nodes.code.python_template import PythonTemplateTransformer +from core.helper.code_executor.python_transformer import PythonTemplateTransformer # Code Executor CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '') @@ -24,7 +25,7 @@ class Data(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: str, code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -32,7 +33,13 @@ def execute_code(cls, language: str, code: str, inputs: dict) -> dict: :param inputs: inputs :return: """ - runner = PythonTemplateTransformer.transform_caller(code, inputs) + template_transformer = None + if language == 'python3': + template_transformer = PythonTemplateTransformer + else: + raise CodeExecutionException('Unsupported language') + + runner = template_transformer.transform_caller(code, inputs) url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' headers = { @@ -67,4 +74,4 @@ def execute_code(cls, language: str, 
code: str, inputs: dict) -> dict: if response.data.stderr: raise CodeExecutionException(response.data.stderr) - return PythonTemplateTransformer.transform_response(response.data.stdout) \ No newline at end of file + return template_transformer.transform_response(response.data.stdout) \ No newline at end of file diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py new file mode 100644 index 00000000000000..f87f5c14cbbd7d --- /dev/null +++ b/api/core/helper/code_executor/javascript_transformer.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py new file mode 100644 index 00000000000000..f87f5c14cbbd7d --- /dev/null +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/api/core/workflow/nodes/code/python_template.py b/api/core/helper/code_executor/python_transformer.py similarity index 90% rename from api/core/workflow/nodes/code/python_template.py rename to api/core/helper/code_executor/python_transformer.py index 03dfee36f3e488..7b862649d8a211 100644 --- a/api/core/workflow/nodes/code/python_template.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -1,6 +1,8 @@ import json import re +from core.helper.code_executor.template_transformer import TemplateTransformer + PYTHON_RUNNER = """# declare main function here {{code}} @@ -19,7 +21,7 @@ """ -class PythonTemplateTransformer: +class PythonTemplateTransformer(TemplateTransformer): @classmethod def transform_caller(cls, code: str, inputs: dict) -> str: """ diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py new file mode 100644 index 00000000000000..5505df87493c02 --- /dev/null +++ b/api/core/helper/code_executor/template_transformer.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + + +class TemplateTransformer(ABC): + @classmethod + @abstractmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + pass + + @classmethod + @abstractmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + pass \ No newline at end of file diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 3d3c475d067ec5..7d3162d9830d64 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,9 +1,9 @@ from typing import Optional, Union, cast +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.code.code_executor import CodeExecutionException, CodeExecutor from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 2212d77e2d4d88..6a18d181cb2ecc 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -16,6 +16,6 @@ class Output(BaseModel): variables: list[VariableSelector] answer: str - 
code_language: str
+    code_language: Literal['python3', 'javascript']
     code: str
     outputs: dict[str, Output]
diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py
index e3e864b6b02f04..4ee76deb83eaa8 100644
--- a/api/core/workflow/nodes/http_request/http_request_node.py
+++ b/api/core/workflow/nodes/http_request/http_request_node.py
@@ -1,6 +1,5 @@
 from typing import cast

-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py
new file mode 100644
index 00000000000000..2d3d35b84c384f
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/entities.py
@@ -0,0 +1,14 @@
+from typing import Literal, Union
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class TemplateTransformNodeData(BaseNodeData):
+    """
+    Template Transform Node Data.
+    """
+    variables: list[VariableSelector]
+    template: str
\ No newline at end of file
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 2bf26e307eaaad..3fb880d926e1a4 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,9 +1,18 @@
-from typing import Optional
+from typing import Optional, cast
+
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
+from models.workflow import WorkflowNodeExecutionStatus


 class TemplateTransformNode(BaseNode):
+    _node_data_cls = TemplateTransformNodeData
+    _node_type = NodeType.TEMPLATE_TRANSFORM
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
@@ -23,3 +32,51 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
                 "template": "{{ arg1 }}"
             }
         }
+
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+        """
+        Run node
+        """
+        node_data = self.node_data
+        node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data)
+
+        # Get variables
+        variables = {}
+        for variable_selector in node_data.variables:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            variables[variable] = value
+
+        # Run code
+        try:
+            result = CodeExecutor.execute_code(
+                language='jina2',
+                code=node_data.template,
+                inputs=variables
+            )
+        except CodeExecutionException as e:
+            return NodeRunResult(
+                inputs=variables,
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variables,
+            outputs=result['result']
+        )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, 
node_data: TemplateTransformNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return { + variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 51f6ab49cf15bc1edbaa68c29288057cda5c1a99 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:50:11 +0800 Subject: [PATCH 098/160] fix: linter --- api/core/workflow/nodes/template_transform/entities.py | 2 -- .../nodes/template_transform/template_transform_node.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index 2d3d35b84c384f..d9099a8118498e 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,6 +1,4 @@ -from typing import Literal, Union -from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 3fb880d926e1a4..724b84495c53e9 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,9 +1,8 @@ from typing import Optional, cast + from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool - from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus From de3978fdbb7a0b41883afd493af4abee718f651f Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 13:19:17 +0800 Subject: [PATCH 099/160] optimize db connections --- api/config.py | 2 ++ api/core/app/apps/advanced_chat/app_generator.py | 13 ++++++++++--- .../apps/advanced_chat/generate_task_pipeline.py | 2 ++ api/core/app/apps/message_based_app_generator.py | 8 ++++++++ .../app/apps/workflow/generate_task_pipeline.py | 2 ++ .../apps/workflow_based_generate_task_pipeline.py | 11 +++++++++++ api/core/workflow/workflow_engine_manager.py | 5 +++++ 7 files changed, 40 insertions(+), 3 deletions(-) diff --git a/api/config.py b/api/config.py index a6bc731b820a4e..a4ec6fcef95731 100644 --- a/api/config.py +++ b/api/config.py @@ -27,6 +27,7 @@ 'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'DEPLOY_ENV': 'PRODUCTION', 'SQLALCHEMY_POOL_SIZE': 30, + 'SQLALCHEMY_MAX_OVERFLOW': 10, 'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_ECHO': 'False', 'SENTRY_TRACES_SAMPLE_RATE': 1.0, @@ -148,6 +149,7 @@ def __init__(self): self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" self.SQLALCHEMY_ENGINE_OPTIONS = { 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), + 'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')), 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) } diff --git 
a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a0f197ec374465..50b561dfe63d1c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -95,6 +95,12 @@ def generate(self, app_model: App, extras=extras ) + workflow = db.session.query(Workflow).filter(Workflow.id == workflow.id).first() + user = (db.session.query(Account).filter(Account.id == user.id).first() + if isinstance(user, Account) + else db.session.query(EndUser).filter(EndUser.id == user.id).first()) + db.session.close() + # init generate records ( conversation, @@ -153,6 +159,8 @@ def _generate_worker(self, flask_app: Flask, conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + db.session.close() + # chatbot app runner = AdvancedChatAppRunner() runner.run( @@ -177,7 +185,7 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, @@ -198,6 +206,7 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh :return: """ # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, workflow=workflow, @@ -216,5 +225,3 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh else: logger.exception(e) raise e - # finally: - # db.session.remove() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 048b429304c980..6991b8704af183 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -122,6 +122,8 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
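The pool settings this patch raises feed straight into SQLAlchemy's connection pool: pool_size is the number of persistent connections, max_overflow allows temporary connections beyond that under burst load, and pool_recycle retires connections before the server side times them out. A minimal sketch of the engine configuration these options produce, with a placeholder DSN (the real URI is assembled from the DB_* variables in api/config.py):

    from sqlalchemy import create_engine

    # Mirrors SQLALCHEMY_ENGINE_OPTIONS as assembled in api/config.py:
    # up to 30 pooled connections, at most 10 extra "overflow" connections
    # under burst load, each recycled after an hour.
    engine = create_engine(
        "postgresql://user:pass@localhost:5432/dify",  # placeholder DSN
        pool_size=30,
        max_overflow=10,
        pool_recycle=3600,
    )

With these values a single process can hold at most 40 connections (pool_size plus max_overflow), which is why the hunks above also start releasing sessions as soon as a request no longer needs one.
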
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 0e76c96ff7fc6b..be7538ea07b77d 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -177,6 +177,9 @@ def _init_generate_records(self, db.session.add(conversation) db.session.commit() + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + db.session.close() + message = Message( app_id=app_config.app_id, model_provider=model_provider, @@ -204,6 +207,9 @@ def _init_generate_records(self, db.session.add(message) db.session.commit() + message = db.session.query(Message).filter(Message.id == message.id).first() + db.session.close() + for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, @@ -218,6 +224,8 @@ def _init_generate_records(self, db.session.add(message_file) db.session.commit() + db.session.close() + return conversation, message def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 26e4769fa61ce1..2c2f941beefcb2 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -99,6 +99,8 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index 3e9a7b9e1fc98e..640159bae32899 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -61,6 +61,9 @@ def _init_workflow_run(self, workflow: Workflow, db.session.add(workflow_run) db.session.commit() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run.id).first() + db.session.close() + return workflow_run def _workflow_run_success(self, workflow_run: WorkflowRun, @@ -85,6 +88,7 @@ def _workflow_run_success(self, workflow_run: WorkflowRun, workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -112,6 +116,7 @@ def _workflow_run_failed(self, workflow_run: WorkflowRun, workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -151,6 +156,10 @@ def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, db.session.add(workflow_node_execution) db.session.commit() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution.id).first()) + db.session.close() + return workflow_node_execution def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, @@ -179,6 +188,7 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_node_execution @@ -198,5 +208,6 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return 
workflow_node_execution diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 0b96717de7f517..50f79df1f069e9 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -19,6 +19,7 @@ from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode +from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, @@ -282,6 +283,8 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, predecessor_node_id=predecessor_node.node_id if predecessor_node else None ) + db.session.close() + workflow_nodes_and_result = WorkflowNodeAndResult( node=node, result=None @@ -339,6 +342,8 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) + db.session.close() + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, node: BaseNode, node_run_result: NodeRunResult) -> None: From 7e4daf131e7da3ab7eb081020edc01260f0d97b6 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 14:49:52 +0800 Subject: [PATCH 100/160] optimize db connections --- api/core/app/apps/advanced_chat/app_generator.py | 7 ------- .../app/apps/advanced_chat/generate_task_pipeline.py | 6 ++++-- api/core/app/apps/message_based_app_generator.py | 10 ++-------- api/core/app/apps/workflow/generate_task_pipeline.py | 6 ++++-- .../app/apps/workflow_based_generate_task_pipeline.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 50b561dfe63d1c..b1bc8399660a59 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -95,12 +95,6 @@ def generate(self, app_model: App, extras=extras ) - workflow = db.session.query(Workflow).filter(Workflow.id == workflow.id).first() - user = (db.session.query(Account).filter(Account.id == user.id).first() - if isinstance(user, Account) - else db.session.query(EndUser).filter(EndUser.id == user.id).first()) - db.session.close() - # init generate records ( conversation, @@ -206,7 +200,6 @@ def _handle_advanced_chat_response(self, application_generate_entity: AdvancedCh :return: """ # init generate task pipeline - generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, workflow=workflow, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6991b8704af183..88ac5fd2357ce7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -122,13 +122,15 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream - db.session.close() - def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + if self._stream: return self._process_stream_response() else: diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index be7538ea07b77d..5d0f4bc63a6549 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -176,9 +176,7 @@ def _init_generate_records(self, db.session.add(conversation) db.session.commit() - - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - db.session.close() + db.session.refresh(conversation) message = Message( app_id=app_config.app_id, @@ -206,9 +204,7 @@ def _init_generate_records(self, db.session.add(message) db.session.commit() - - message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.close() + db.session.refresh(message) for file in application_generate_entity.files: message_file = MessageFile( @@ -224,8 +220,6 @@ def _init_generate_records(self, db.session.add(message_file) db.session.commit() - db.session.close() - return conversation, message def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2c2f941beefcb2..9bd20f978519a9 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -99,13 +99,15 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream - db.session.close() - def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + if self._stream: return self._process_stream_response() else: diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index 640159bae32899..d29cee3ac4a086 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -60,8 +60,7 @@ def _init_workflow_run(self, workflow: Workflow, db.session.add(workflow_run) db.session.commit() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run.id).first() + db.session.refresh(workflow_run) db.session.close() return workflow_run @@ -155,9 +154,7 @@ def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, db.session.add(workflow_node_execution) db.session.commit() - - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution.id).first()) + db.session.refresh(workflow_node_execution) db.session.close() return workflow_node_execution From 8b832097de7316238a0713c05eca839a468863b0 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 16:29:55 +0800 Subject: [PATCH 101/160] optimize db connections --- api/controllers/console/app/app.py | 78 ++++----- api/controllers/console/app/model_config.py | 149 +++++++++--------- .../easy_ui_based_app/dataset/manager.py | 3 +- .../app/apps/advanced_chat/app_generator.py | 2 - api/core/app/apps/advanced_chat/app_runner.py | 2 + api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 4 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 + .../app/apps/message_based_app_generator.py | 2 - api/core/app/apps/workflow/app_runner.py | 2 + api/core/tools/tool_manager.py | 2 +- api/models/model.py | 2 +- 14 files changed, 131 insertions(+), 123 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 66bcbccefe4f00..94406030697458 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,5 @@ +import json + from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden, BadRequest @@ -6,6 +8,8 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from core.agent.entities import AgentToolEntity +from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, @@ -14,10 +18,8 @@ from libs.login import login_required from services.app_service import AppService from models.model import App, AppModelConfig, AppMode -from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager -from core.entities.application_entities import AgentToolEntity ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] @@ -108,41 +110,43 @@ class AppApi(Resource): def get(self, app_model): """Get app detail""" # get original app model config - model_config: AppModelConfig = app_model.app_model_config - agent_mode = 
model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - masked_parameter = {} - - # override tool parameters - tool['tool_parameters'] = masked_parameter - except Exception as e: - pass - - # override agent mode - model_config.agent_mode = json.dumps(agent_mode) + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + model_config: AppModelConfig = app_model.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + agent_tool_entity = AgentToolEntity(**tool) + # get tool + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} + + # override tool parameters + tool['tool_parameters'] = masked_parameter + except Exception as e: + pass + + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) + db.session.commit() return app_model diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 1301d12da4e560..41b7151ba65013 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AgentToolEntity +from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated @@ -38,90 +38,91 @@ def post(self, app_model): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app.app_model_config_id - ).first() - agent_mode = original_app_model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - parameter_map = {} - 
masked_parameter_map = {} - tool_map = {} - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - except Exception as e: - continue - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - parameters = {} - masked_parameter = {} - - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - masked_parameter_map[key] = masked_parameter - parameter_map[key] = parameters - tool_map[key] = tool_runtime - - # encrypt agent tool parameters if it's secret-input - agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - - # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - if key in tool_map: - tool_runtime = tool_map[key] - else: + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app_model.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + + agent_tool_entity = AgentToolEntity(**tool) + # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, agent_callback=None ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) except Exception as e: continue - - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - manager.delete_tool_parameters_cache() - - # override parameters if it equals to masked parameters - if agent_tool_entity.tool_parameters: - if key not in masked_parameter_map: - continue - if agent_tool_entity.tool_parameters == masked_parameter_map[key]: - agent_tool_entity.tool_parameters = parameter_map[key] + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} + + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = masked_parameter + parameter_map[key] = parameters + tool_map[key] = 
tool_runtime + + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + except Exception as e: + continue + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] - # encrypt parameters - if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - # update app model config - new_app_model_config.agent_mode = json.dumps(agent_mode) + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) db.session.add(new_app_model_config) db.session.flush() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 4c08f62d27217b..c10aa98dbacf9e 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -123,7 +123,8 @@ def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: di if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets") + need_manual_query_datasets = (config.get("dataset_configs") + and config["dataset_configs"].get("datasets", {}).get("datasets")) if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index b1bc8399660a59..1a33a3230bd773 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -153,8 +153,6 @@ def _generate_worker(self, flask_app: Flask, conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - db.session.close() - # chatbot app runner = AdvancedChatAppRunner() runner.run( diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 3279e00355e915..c42620b92f9266 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -72,6 +72,8 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, ): return + db.session.close() + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( 
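
The db.session.close() calls added across these runners follow one discipline: load everything the long-running step needs, hand the scoped session's connection back to the pool, and re-query (or refresh) before touching ORM state again. A minimal sketch of that pattern, assuming Flask-SQLAlchemy's scoped session and the project's Conversation model; run_agent is a hypothetical stand-in for the long-running work:

    from extensions.ext_database import db
    from models.model import Conversation

    def run_with_released_connection(conversation_id: str) -> None:
        # Re-fetch in the current thread so the object is attached to
        # this thread's session rather than the caller's.
        conversation = db.session.query(Conversation) \
            .filter(Conversation.id == conversation_id).first()
        db.session.close()  # return the pooled connection

        # Attributes loaded above stay readable on the detached object;
        # unloaded lazy attributes would raise, so fetch them beforehand.
        run_agent(conversation)  # hypothetical long-running step

Closing (rather than removing) the scoped session keeps it reusable: the next db.session access simply checks a fresh connection out of the pool.
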
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 700a340c969980..cc9b0785f56106 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -193,4 +193,4 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2e142c63f1ff47..0dc8a1e2184abe 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -201,8 +201,8 @@ def run(self, application_generate_entity: AgentChatAppGenerateEntity, if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - db.session.refresh(conversation) - db.session.refresh(message) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + message = db.session.query(Message).filter(Message.id == message.id).first() db.session.close() # start agent runner diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 317d045c043083..58287ba6587d09 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -193,4 +193,4 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index b948938aac24ad..fb6246972075cf 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -182,7 +182,7 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def generate_more_like_this(self, app_model: App, message_id: str, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 04adf77be5b2c4..649d73d96180fa 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -160,6 +160,8 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, model=application_generate_entity.model_config.model ) + db.session.close() + invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=application_generate_entity.model_config.parameters, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 5d0f4bc63a6549..5e676c40bd56f5 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -64,8 +64,6 @@ def _handle_response(self, application_generate_entity: Union[ else: logger.exception(e) raise e - finally: - db.session.remove() def _get_conversation_by_user(self, app_model: App, conversation_id: str, user: Union[Account, EndUser]) -> Conversation: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 59a385cb38c5a3..2d032fcdcb0922 100644 --- 
a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -57,6 +57,8 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, ): return + db.session.close() + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab7421..24b2f287c1319a 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,8 +5,8 @@ from os import listdir, path from typing import Any, Union +from core.agent.entities import AgentToolEntity from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject diff --git a/api/models/model.py b/api/models/model.py index 6856c4e1b07d14..5a7311a0c72ecc 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -322,7 +322,7 @@ def to_dict(self) -> dict: } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config['opening_statement'] + self.opening_statement = model_config.get('opening_statement') self.suggested_questions = json.dumps(model_config['suggested_questions']) \ if model_config.get('suggested_questions') else None self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ From 61a1aadf9ca04c09008daf9d6914d2c60ada7c42 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 16:59:17 +0800 Subject: [PATCH 102/160] optimize workflow db connections --- .../advanced_chat/generate_task_pipeline.py | 99 ++++++++++--------- .../apps/workflow/generate_task_pipeline.py | 98 +++++++++--------- .../workflow_based_generate_task_pipeline.py | 4 + 3 files changed, 105 insertions(+), 96 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 88ac5fd2357ce7..d5d3feded0328d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -59,7 +59,7 @@ class NodeExecutionInfo(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -72,7 +72,7 @@ class Config: metadata: dict = {} usage: LLMUsage - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -168,8 +168,7 @@ def _process_blocking_response(self) -> dict: elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) @@ -218,8 +217,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = 
self._on_workflow_start() response = { 'event': 'workflow_started', @@ -234,8 +232,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -253,8 +250,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -285,8 +281,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -435,7 +430,7 @@ def _process_stream_response(self) -> Generator: else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -452,11 +447,16 @@ def _on_workflow_start(self) -> None: } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id + + db.session.close() + + return workflow_run - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -465,19 +465,26 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> None: ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - 
workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -495,19 +502,24 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] - self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + + db.session.close() + + return workflow_node_execution - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -516,7 +528,7 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -524,39 +536,30 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. 
- :param workflow_node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - return workflow_node_execution + return workflow_run def _save_message(self) -> None: """ @@ -567,7 +570,7 @@ def _save_message(self) -> None: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.workflow_run_id = self._task_state.workflow_run.id + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9bd20f978519a9..8516feb87d5233 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -45,7 +45,7 @@ class NodeExecutionInfo(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -57,7 +57,7 @@ class Config: answer: str = "" metadata: dict = {} - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -130,8 +130,7 @@ def _process_blocking_response(self) -> dict: elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if self._output_moderation_handler: @@ -179,8 +178,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_start() response = { 'event': 'workflow_started', @@ -195,8 +193,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -214,8 +211,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) response = { 'event': 'node_finished', @@ -240,8 +236,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if self._output_moderation_handler: @@ -257,7 +252,7 @@ def _process_stream_response(self) -> Generator: replace_response = { 'event': 'text_replace', 'task_id': 
self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': self._task_state.answer } @@ -317,7 +312,7 @@ def _process_stream_response(self) -> Generator: response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': event.text } @@ -329,7 +324,7 @@ def _process_stream_response(self) -> Generator: else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -344,11 +339,16 @@ def _on_workflow_start(self) -> None: } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id + + db.session.close() + + return workflow_run - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -357,7 +357,7 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> None: ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) @@ -366,11 +366,17 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> None: self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -383,19 +389,24 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) # remove running node execution info del self._task_state.running_node_execution_infos[event.node_id] - self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, 
event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -404,7 +415,7 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -412,39 +423,30 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. 
-        :param workflow_node_execution_id: workflow node execution id
-        :return:
-        """
-        workflow_node_execution = (db.session.query(WorkflowNodeExecution)
-                                   .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
-
-        return workflow_node_execution
+        return workflow_run
 
     def _save_workflow_app_log(self) -> None:
         """
@@ -461,7 +463,7 @@ def _handle_chunk(self, text: str) -> dict:
         """
         response = {
             'event': 'text_chunk',
-            'workflow_run_id': self._task_state.workflow_run.id,
+            'workflow_run_id': self._task_state.workflow_run_id,
             'task_id': self._application_generate_entity.task_id,
             'data': {
                 'text': text
diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py
index d29cee3ac4a086..2b373d28e83957 100644
--- a/api/core/app/apps/workflow_based_generate_task_pipeline.py
+++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py
@@ -87,6 +87,7 @@ def _workflow_run_success(self, workflow_run: WorkflowRun,
         workflow_run.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()
 
         return workflow_run
@@ -115,6 +116,7 @@ def _workflow_run_failed(self, workflow_run: WorkflowRun,
         workflow_run.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()
 
         return workflow_run
@@ -185,6 +187,7 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode
         workflow_node_execution.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()
 
         return workflow_node_execution
@@ -205,6 +208,7 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE
         workflow_node_execution.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()
 
         return workflow_node_execution

From 2d8497f79baebf0c52837eebf96077bb22df6d6d Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 10 Mar 2024 17:11:39 +0800
Subject: [PATCH 103/160] add readme for db connection management in App
 Runner and Task Pipeline

---
 api/core/app/apps/README.md | 45 +++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 api/core/app/apps/README.md

diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
new file mode 100644
index 00000000000000..a59c424a156ffc
--- /dev/null
+++ b/api/core/app/apps/README.md
@@ -0,0 +1,45 @@
+## Guidelines for Database Connection Management in App Runner and Task Pipeline
+
+App Runner contains long-running tasks such as LLM generation and external requests, while Flask-SQLAlchemy's connection-pooling strategy allocates one connection (transaction) per request. That connection stays occupied even while non-DB work is running, so under high concurrency a handful of long-running requests can exhaust the pool and leave no connections to acquire.
+
+Therefore, database operations in App Runner and Task Pipeline must close their connections immediately after use, and it is better to pass IDs rather than Model objects between steps to avoid detached-instance errors.
+
+Examples:
+
+1. Creating a new record:
+
+    ```python
+    app = App(id=1)
+    db.session.add(app)
+    db.session.commit()
+    db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object; closing the session won't affect them
+
+    # Process related app logic
+
+    db.session.close()
+
+    return app.id
+    ```
+
+2. Fetching a record from the table:
+
+    ```python
+    app = db.session.query(App).filter(App.id == app_id).first()
+
+    created_at = app.created_at
+
+    db.session.close()
+    ```
+
+3. Updating a table field:
+
+    ```python
+    app = db.session.query(App).filter(App.id == app_id).first()
+
+    app.updated_at = datetime.utcnow()
+    db.session.commit()
+    db.session.close()
+
+    return app_id
+    ```
+
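To make the detached-instance pitfall behind these guidelines concrete, here is a minimal sketch (not taken from the patches above). It reuses the repo's `db` handle and `App` model from the README examples; `call_llm` is a purely illustrative stand-in for any long-running, non-DB task.

```python
import time

from extensions.ext_database import db
from models.model import App


def call_llm(prompt: str) -> None:
    # stand-in for a slow, non-DB task such as an LLM request
    time.sleep(30)


def load_app_name(app_id: str) -> str:
    # Fetch by ID inside the step that needs the row, copy the scalar
    # values you need into plain variables, then release the connection.
    app = db.session.query(App).filter(App.id == app_id).first()
    app_name = app.name  # copied while the session is still open
    db.session.close()   # the connection returns to the pool here
    return app_name


def run_generation(app_id: str) -> None:
    app_name = load_app_name(app_id)
    call_llm(app_name)  # no database connection is held during the slow call
```

Had `run_generation` received the `App` object itself and touched an unloaded attribute after the session closed, SQLAlchemy would raise `DetachedInstanceError`; passing the ID (or pre-copied fields) sidesteps that while keeping the pool free during the slow call.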
From 1e6feadc7ecc9987cec762befa1d9ccf7f2a9006 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Sun, 10 Mar 2024 17:55:24 +0800
Subject: [PATCH 104/160] fix: code node does not work as expected

---
 api/core/helper/code_executor/code_executor.py  | 14 +++++++-------
 .../helper/code_executor/python_transformer.py  | 10 ++++------
 api/core/workflow/nodes/code/code_node.py       | 10 +++++-----
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index f1bc4fbdafde4b..fb0ad9642aeec1 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -1,5 +1,5 @@
 from os import environ
-from typing import Literal
+from typing import Literal, Optional
 
 from httpx import post
 from pydantic import BaseModel
@@ -16,8 +16,8 @@ class CodeExecutionException(Exception):
 
 class CodeExecutionResponse(BaseModel):
     class Data(BaseModel):
-        stdout: str
-        stderr: str
+        stdout: Optional[str]
+        error: Optional[str]
 
     code: int
     message: str
@@ -58,9 +58,9 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code:
             raise Exception('Failed to execute code')
         except CodeExecutionException as e:
             raise e
-        except Exception:
+        except Exception as e:
             raise CodeExecutionException('Failed to execute code')
-        
+
         try:
             response = response.json()
         except:
@@ -71,7 +71,7 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code:
         if response.code != 0:
             raise CodeExecutionException(response.message)
 
-        if response.data.stderr:
-            raise CodeExecutionException(response.data.stderr)
+        if response.data.error:
+            raise CodeExecutionException(response.data.error)
 
         return template_transformer.transform_response(response.data.stdout)
\ No newline at end of file
diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py
index 7b862649d8a211..27863ee4435354 100644
--- a/api/core/helper/code_executor/python_transformer.py
+++ b/api/core/helper/code_executor/python_transformer.py
@@ -11,11 +11,11 @@
 output = main(**{{inputs}})
 
 # convert output to json and print
-result = '''
-<>
+output = json.dumps(output, indent=4)
+
+result = f'''<>
 {output}
-<>
-'''
+<>'''
 
 print(result)
 """
@@ -47,11 +47,9 @@ def transform_response(cls, response: str) -> dict:
         :param response: response
         :return:
         """
-        # extract result
         result = re.search(r'<>(.*)<>', response, re.DOTALL)
         if not result:
             raise ValueError('Failed to parse result')
-
         result = result.group(1)
 
         return json.loads(result)
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 7d3162d9830d64..9cc58651334855 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -101,7 +101,6 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
                 )
 
             variables[variable] = value
-
         # Run code
         try:
             result = CodeExecutor.execute_code(
@@ -109,15 +108,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
                 code=code,
                 inputs=variables
             )
-        except CodeExecutionException as e:
+
+            # Transform result
+            result = self._transform_result(result, node_data.outputs)
+        except (CodeExecutionException, ValueError) as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=variables,
                 error=str(e)
             )
 
-        # Transform result
-        result = self._transform_result(result, node_data.outputs)
-
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variables,
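The fix above hinges on one technique: the generated runner JSON-serializes the user function's return value, fences it between sentinel markers on stdout, and the transformer recovers it with a DOTALL regex. Here is a standalone sketch of that round trip; the marker string `<<RESULT>>` is illustrative only (the patch uses its own marker, rendered as `<>` in this excerpt).

```python
import json
import re

# What the generated runner prints: interpreter noise may precede the
# fenced, JSON-serialized result of the user's main() function.
output = json.dumps({"result": 3}, indent=4)
stdout = f"warning: some interpreter noise\n<<RESULT>>\n{output}\n<<RESULT>>"

# What the transformer does: re.DOTALL lets '.' match the newlines
# inside the fence, so the whole JSON payload is captured.
match = re.search(r"<<RESULT>>(.*)<<RESULT>>", stdout, re.DOTALL)
if not match:
    raise ValueError("Failed to parse result")

print(json.loads(match.group(1)))  # {'result': 3}
```

Serializing to JSON before printing is what lets arbitrary dict outputs survive the stdout channel intact; anything the sandboxed code prints outside the fence is simply ignored by the regex.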
From 751489fa547487bd521e4aa3a6bc297b577a2511 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 10 Mar 2024 18:01:55 +0800
Subject: [PATCH 105/160] modify readme

---
 api/core/app/apps/README.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
index a59c424a156ffc..856690dc57c6c9 100644
--- a/api/core/app/apps/README.md
+++ b/api/core/app/apps/README.md
@@ -14,7 +14,7 @@ Examples:
     db.session.commit()
     db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object; closing the session won't affect them
 
-    # Process related app logic
+    # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
 
     db.session.close()
 
     return app.id
@@ -29,6 +29,9 @@ Examples:
     created_at = app.created_at
 
     db.session.close()
+
+    # Handle tasks (including long-running).
+
     ```
 
 3. Updating a table field:

From 80312620064d8f946ed7fbd449aa9f0f82c8a612 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Sun, 10 Mar 2024 18:41:01 +0800
Subject: [PATCH 106/160] feat: workflow mock test

---
 .github/workflows/api-workflow-tests.yaml     |  30 +++
 api/core/workflow/nodes/code/code_node.py     |  10 +-
 api/tests/integration_tests/.env.example      |   6 +-
 .../integration_tests/workflow/__init__.py    |   0
 .../workflow/nodes/__mock/code_executor.py    |  27 ++
 .../workflow/nodes/test_code.py               | 244 ++++++++++++++++++
 6 files changed, 311 insertions(+), 6 deletions(-)
 create mode 100644 .github/workflows/api-workflow-tests.yaml
 create mode 100644 api/tests/integration_tests/workflow/__init__.py
 create mode 100644 api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
 create mode 100644 api/tests/integration_tests/workflow/nodes/test_code.py

diff --git a/.github/workflows/api-workflow-tests.yaml b/.github/workflows/api-workflow-tests.yaml
new file mode 100644
index 00000000000000..e4e35c6c443819
--- /dev/null
+++ b/.github/workflows/api-workflow-tests.yaml
@@ -0,0 +1,30 @@
+name: Run Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    env:
+      MOCK_SWITCH: true
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+          cache: 'pip'
+          cache-dependency-path: ./api/requirements.txt
+
+      - name: Install dependencies
+        run: pip install -r ./api/requirements.txt
+
+      - name: Run pytest
+        run: pytest api/tests/integration_tests/workflow
\ No newline at end of file
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 9cc58651334855..8034f4e55da709 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -132,10 +132,10 @@ def _check_string(self, value: str, variable: str) -> str:
         :return:
         """
         if not isinstance(value, str):
-            raise ValueError(f"{variable} in input form must be a string")
+            raise ValueError(f"{variable} in output form must be a string")
 
         if len(value) > MAX_STRING_LENGTH:
-            raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
+            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
 
         return value.replace('\x00', '')
 
@@ -147,7 +147,7 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f
         :return:
         """
         if not isinstance(value, int | float):
-            raise ValueError(f"{variable} in input form must be a number")
+            raise ValueError(f"{variable} in output form must be a number")
 
         if value > MAX_NUMBER or value < MIN_NUMBER:
             raise ValueError(f'{variable} in input form is out of range.')
@@ -205,7 +205,7 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.
 
                         if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                             raise ValueError(
-                                f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
+                                f'{prefix}.{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
                             )
 
                         transformed_result[output_name] = [
@@ -224,7 +224,7 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.
 
                         if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                             raise ValueError(
-                                f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
+                                f'{prefix}.{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
                             )
 
                         transformed_result[output_name] = [
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 04abacf73d2c17..dd1baa79d4ec9d 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -66,4 +66,8 @@ JINA_API_KEY=
 OLLAMA_BASE_URL=
 
 # Mock Switch
-MOCK_SWITCH=false
\ No newline at end of file
+MOCK_SWITCH=false
+
+# CODE EXECUTION CONFIGURATION
+CODE_EXECUTION_ENDPOINT=
+CODE_EXECUTION_API_KEY=
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/__init__.py b/api/tests/integration_tests/workflow/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
new file mode 100644
index 00000000000000..b95c76b1338615
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -0,0 +1,27 @@
+import os
+import pytest
+
+from typing import Literal
+from _pytest.monkeypatch import MonkeyPatch
+from core.helper.code_executor.code_executor import CodeExecutor
+
+MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+
+class MockedCodeExecutor:
+    @classmethod
+    def invoke(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict:
+        # invoke directly
+        if language == 'python3':
+            return {
+                "result": 3
+            }
+
+@pytest.fixture
+def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
+    if not MOCK:
+        yield
+        return
+
+    monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke)
+    yield
+    monkeypatch.undo()
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
new file mode 100644
index 00000000000000..2885b9f4585996
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -0,0 +1,244 @@
+import pytest
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.code.code_node import CodeNode
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowRunStatus
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode(config={ + 'id': '1', + 'data': { + 'outputs': { + 'result': { + 'type': 'number', + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] == 3 + assert result.error is None + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code_output_validator(setup_code_executor_mock): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode(config={ + 'id': '1', + 'data': { + "outputs": { + "result": { + "type": "string", + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == 'result in output form must be a string' + +def test_execute_code_output_validator_depth(): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode(config={ + 'id': '1', + 'data': { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + } + } + } + } + } + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333], + 
"string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": "1", + "string_validator": 1, + "number_array_validator": ["1", "2", "3", "3.333"], + "string_array_validator": [1, 2, 3], + "object_validator": { + "result": "1", + "depth": { + "depth": { + "depth": "1" + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1" * 2000, + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333] * 2000, + "string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + \ No newline at end of file From 9d0a832e403654ae73e0857eecffa4aedc077321 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 10 Mar 2024 18:41:49 +0800 Subject: [PATCH 107/160] refactor: github actions --- .github/workflows/{tool-tests.yaml => api-tools-tests.yaml} | 0 .github/workflows/api-workflow-tests.yaml | 1 + 2 files changed, 1 insertion(+) rename .github/workflows/{tool-tests.yaml => api-tools-tests.yaml} (100%) diff --git a/.github/workflows/tool-tests.yaml b/.github/workflows/api-tools-tests.yaml similarity index 100% rename from .github/workflows/tool-tests.yaml rename to .github/workflows/api-tools-tests.yaml diff --git a/.github/workflows/api-workflow-tests.yaml b/.github/workflows/api-workflow-tests.yaml index e4e35c6c443819..37a138b44dc1e0 100644 --- a/.github/workflows/api-workflow-tests.yaml +++ b/.github/workflows/api-workflow-tests.yaml @@ -4,6 +4,7 @@ on: pull_request: branches: - main + - deploy/dev jobs: test: From be6836998320c2428d3a1a9003b1ad8688c3ecbd Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 20:02:10 +0800 Subject: [PATCH 108/160] add workflow_app_log codes --- .../apps/workflow/generate_task_pipeline.py | 40 ++++++++++++++++--- api/models/workflow.py | 23 +++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 8516feb87d5233..7a244151f2ddb1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -32,7 +32,15 @@ from extensions.ext_database import db from models.account import Account from models.model import EndUser -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom +from models.workflow import ( + Workflow, + WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) logger = logging.getLogger(__name__) @@ -142,7 +150,7 @@ def _process_blocking_response(self) -> dict: ) # save workflow app log - self._save_workflow_app_log() + 
self._save_workflow_app_log(workflow_run) response = { 'task_id': self._application_generate_entity.task_id, @@ -261,7 +269,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(replace_response) # save workflow app log - self._save_workflow_app_log() + self._save_workflow_app_log(workflow_run) workflow_run_response = { 'event': 'workflow_finished', @@ -448,12 +456,34 @@ def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEv return workflow_run - def _save_workflow_app_log(self) -> None: + def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. :return: """ - pass # todo + invoke_from = self._application_generate_entity.invoke_from + if invoke_from == InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + elif invoke_from == InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + elif invoke_from == InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + else: + # not save log for debugging + return + + workflow_app_log = WorkflowAppLog( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + created_from=created_from.value, + created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), + created_by=self._user.id, + ) + db.session.add(workflow_app_log) + db.session.commit() + db.session.close() def _handle_chunk(self, text: str) -> dict: """ diff --git a/api/models/workflow.py b/api/models/workflow.py index 9768c364dd66ec..5a3cdcf83c5570 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -433,6 +433,29 @@ def process_data_dict(self): def execution_metadata_dict(self): return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) + +class WorkflowAppLogCreatedFrom(Enum): + """ + Workflow App Log Created From Enum + """ + SERVICE_API = 'service-api' + WEB_APP = 'web-app' + INSTALLED_APP = 'installed-app' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow app log created from value {value}') + + class WorkflowAppLog(db.Model): """ Workflow App execution log, excluding workflow debugging records. 
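The `value_of` helper added here follows the same lookup-by-value pattern as the other enums in `models/workflow.py`. A small self-contained sketch of the pattern and its failure mode (member set trimmed for illustration):

```python
from enum import Enum


class CreatedFrom(Enum):
    # illustrative subset of WorkflowAppLogCreatedFrom's members
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'

    @classmethod
    def value_of(cls, value: str) -> 'CreatedFrom':
        # scan members and match on the stored string value
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f'invalid workflow app log created from value {value}')


assert CreatedFrom.value_of('web-app') is CreatedFrom.WEB_APP

try:
    CreatedFrom.value_of('debugger')  # debug runs never reach the log table
except ValueError as exc:
    print(exc)  # invalid workflow app log created from value debugger
```

Plain `CreatedFrom('debugger')` would raise Python's generic "is not a valid" error; the explicit loop exists mainly to produce a failure message in the codebase's own wording.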
From a0a161886938d5d77521038daf7b4a58e07fd57b Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 20:15:49 +0800 Subject: [PATCH 109/160] add tenant_id / app_id / workflow_id for nodes --- api/core/workflow/entities/workflow_entities.py | 14 +++++++++++--- api/core/workflow/nodes/base_node.py | 13 ++++++++++++- api/core/workflow/workflow_engine_manager.py | 17 ++++++++++++++--- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 768ad6a1303a97..91f9ef95fe51e6 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,7 +3,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType class WorkflowNodeAndResult: @@ -16,7 +16,11 @@ def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): class WorkflowRunState: - workflow: Workflow + tenant_id: str + app_id: str + workflow_id: str + workflow_type: WorkflowType + start_at: float variable_pool: VariablePool @@ -25,6 +29,10 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): - self.workflow = workflow + self.workflow_id = workflow.id + self.tenant_id = workflow.tenant_id + self.app_id = workflow.app_id + self.workflow_type = WorkflowType.value_of(workflow.type) + self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3f2e806433e289..6db25bea7eaa0d 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -12,14 +12,25 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType + tenant_id: str + app_id: str + workflow_id: str + node_id: str node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None callbacks: list[BaseWorkflowCallback] - def __init__(self, config: dict, + def __init__(self, tenant_id: str, + app_id: str, + workflow_id: str, + config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.node_id = config.get("id") if not self.node_id: raise ValueError("Node ID is required.") diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 50f79df1f069e9..d01746ceb8b691 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -122,6 +122,7 @@ def run_workflow(self, workflow: Workflow, while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( + workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, callbacks=callbacks @@ -198,7 +199,8 @@ def _workflow_run_failed(self, error: str, error=error ) - def _get_next_node(self, graph: dict, + def _get_next_node(self, workflow_run_state: WorkflowRunState, + graph: dict, predecessor_node: Optional[BaseNode] = None, callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: """ @@ -216,7 +218,13 @@ def _get_next_node(self, graph: dict, if not predecessor_node: for node_config in nodes: if node_config.get('data', 
{}).get('type', '') == NodeType.START.value: - return StartNode(config=node_config) + return StartNode( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + config=node_config, + callbacks=callbacks + ) else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -256,6 +264,9 @@ def _get_next_node(self, graph: dict, target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, config=target_node_config, callbacks=callbacks ) @@ -354,7 +365,7 @@ def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, :param node_run_result: node run result :return: """ - if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] if workflow_nodes_and_result_before_end: if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: From e0883302d26262e16456ceef14a79a4837b61cb5 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 10 Mar 2024 20:24:16 +0800 Subject: [PATCH 110/160] feat: jinja2 --- .../helper/code_executor/code_executor.py | 7 ++- .../helper/code_executor/jina2_transformer.py | 55 ++++++++++++++++++- .../template_transform_node.py | 6 +- .../workflow/nodes/__mock/code_executor.py | 2 +- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index fb0ad9642aeec1..a62cf4de951120 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from httpx import post from pydantic import BaseModel from yarl import URL +from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer @@ -25,7 +26,7 @@ class Data(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -36,6 +37,8 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: template_transformer = None if language == 'python3': template_transformer = PythonTemplateTransformer + elif language == 'jinja2': + template_transformer = Jinja2TemplateTransformer else: raise CodeExecutionException('Unsupported language') @@ -46,7 +49,7 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: 'X-Api-Key': CODE_EXECUTION_API_KEY } data = { - 'language': language, + 'language': language if language != 'jinja2' else 'python3', 'code': runner, } diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py index f87f5c14cbbd7d..87e8ce130f2820 100644 --- a/api/core/helper/code_executor/jina2_transformer.py +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -1 +1,54 @@ -# TODO \ No newline at end of file +import json +import re + +from core.helper.code_executor.template_transformer import 
TemplateTransformer + +PYTHON_RUNNER = """ +import jinja2 + +template = jinja2.Template('''{{code}}''') + +def main(**inputs): + return template.render(**inputs) + +# execute main function, and return the result +output = main(**{{inputs}}) + +result = f'''<>{output}<>''' + +print(result) + +""" + +class Jinja2TemplateTransformer(TemplateTransformer): + @classmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + + # transform jinja2 template to python code + runner = PYTHON_RUNNER.replace('{{code}}', code) + runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4)) + + return runner + + @classmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + # extract result + result = re.search(r'<>(.*)<>', response, re.DOTALL) + if not result: + raise ValueError('Failed to parse result') + result = result.group(1) + + return { + 'result': result + } \ No newline at end of file diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 724b84495c53e9..a037332f4bcaf6 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -52,7 +52,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # Run code try: result = CodeExecutor.execute_code( - language='jina2', + language='jinja2', code=node_data.template, inputs=variables ) @@ -66,7 +66,9 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - outputs=result['result'] + outputs={ + 'output': result['result'] + } ) @classmethod diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index b95c76b1338615..a1c8eb71dc5f23 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -9,7 +9,7 @@ class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict: + def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: # invoke directly if language == 'python3': return { From f8cba2679e4ce24667a8f365bf81631b71e5c156 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 10 Mar 2024 21:12:07 +0800 Subject: [PATCH 111/160] fix: linter --- api/core/helper/code_executor/code_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index a62cf4de951120..21a8ca5f9f1ba7 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,8 +4,8 @@ from httpx import post from pydantic import BaseModel from yarl import URL -from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer +from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer # Code Executor From 5e4bd9fc38ba406569099e9b49965a80ae5ef615 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 
2024 13:54:11 +0800 Subject: [PATCH 112/160] feat: tool node --- api/core/agent/base_agent_runner.py | 69 ---------- api/core/agent/cot_agent_runner.py | 8 +- api/core/agent/fc_agent_runner.py | 8 +- api/core/tools/tool_manager.py | 114 ++++++++++------ api/core/tools/utils/message_transformer.py | 85 ++++++++++++ api/core/workflow/nodes/tool/entities.py | 23 ++++ api/core/workflow/nodes/tool/tool_node.py | 136 +++++++++++++++++++- 7 files changed, 334 insertions(+), 109 deletions(-) create mode 100644 api/core/tools/utils/message_transformer.py create mode 100644 api/core/workflow/nodes/tool/entities.py diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0901b7e96598a0..14602a72656b39 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,7 +2,6 @@ import logging import uuid from datetime import datetime -from mimetypes import guess_extension from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -39,7 +38,6 @@ ) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool -from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager from extensions.ext_database import db from models.model import Message, MessageAgentThought, MessageFile @@ -462,73 +460,6 @@ def save_agent_thought(self, db.session.commit() db.session.close() - - def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: - """ - Transform tool message into agent thought - """ - result = [] - - for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.IMAGE: - # try to download image - try: - file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_url=message.message) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - except Exception as e: - logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') - # if message is str, encode it to bytes - if isinstance(message.message, str): - message.message = message.message.encode('utf-8') - file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_binary=message.message, - mimetype=mimetype) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' - - # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(ToolInvokeMessage( 
- type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(message) - - return result def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index cbb19aca53ad75..0c5399f5416d65 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,6 +25,7 @@ ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message @@ -280,7 +281,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_parameters=tool_call_args ) # transform tool response to llm friendly response - tool_response = self.transform_tool_invoke_messages(tool_response) + tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_response, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_response) # create message file diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 7c3849a12ca0c4..185d7684c82ad4 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -23,6 +23,7 @@ ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -270,7 +271,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_parameters=tool_call_args, ) # transform tool invoke message to get LLM friendly message - tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) + tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_invoke_message, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_invoke_message) # create message file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 24b2f287c1319a..ea66362195b719 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -34,6 +34,7 @@ ToolParameterConfigurationManager, ) from core.tools.utils.encoder import serialize_base_model_dict +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -225,6 +226,48 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + @staticmethod + def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + """ + init runtime parameter + """ + parameter_value = parameters.get(parameter_rule.name) + if not parameter_value: + # get default value + parameter_value = parameter_rule.default + if not parameter_value and parameter_rule.required: + raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config") + + 
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: + # check if tool_parameter_config in options + options = list(map(lambda x: x.value, parameter_rule.options)) + if parameter_value not in options: + raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + + # convert tool parameter config to correct type + try: + if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER: + # check if tool parameter is integer + if isinstance(parameter_value, int): + parameter_value = parameter_value + elif isinstance(parameter_value, float): + parameter_value = parameter_value + elif isinstance(parameter_value, str): + if '.' in parameter_value: + parameter_value = float(parameter_value) + else: + parameter_value = int(parameter_value) + elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN: + parameter_value = bool(parameter_value) + elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: + parameter_value = str(parameter_value) + elif parameter_rule.type == ToolParameter.ToolParameterType: + parameter_value = str(parameter_value) + except Exception as e: + raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type") + + return parameter_value + @staticmethod def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: """ @@ -239,44 +282,9 @@ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_ca parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM: - # get tool parameter from form - tool_parameter_config = agent_tool.tool_parameters.get(parameter.name) - if not tool_parameter_config: - # get default value - tool_parameter_config = parameter.default - if not tool_parameter_config and parameter.required: - raise ValueError(f"tool parameter {parameter.name} not found in tool config") - - if parameter.type == ToolParameter.ToolParameterType.SELECT: - # check if tool_parameter_config in options - options = list(map(lambda x: x.value, parameter.options)) - if tool_parameter_config not in options: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") - - # convert tool parameter config to correct type - try: - if parameter.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(tool_parameter_config, int): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, float): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, str): - if '.' 
in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config + value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -289,6 +297,38 @@ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_ca tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity + + @staticmethod + def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler): + """ + get the workflow tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=workflow_tool.provider_type, + provider_name=workflow_tool.provider_id, + tool_name=workflow_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + + for parameter in parameters: + # save tool parameter to tool entity memory + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=workflow_tool.provider_id, + provider_type=workflow_tool.provider_type, + ) + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity @staticmethod def get_builtin_provider_icon(provider: str) -> tuple[str, str]: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py new file mode 100644 index 00000000000000..3f456b4eb6adbe --- /dev/null +++ b/api/core/tools/utils/message_transformer.py @@ -0,0 +1,85 @@ +import logging +from mimetypes import guess_extension + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_file_manager import ToolFileManager + +logger = logging.getLogger(__name__) + +class ToolFileMessageTransformer: + @staticmethod + def transform_tool_invoke_messages(messages: list[ToolInvokeMessage], + user_id: str, + tenant_id: str, + conversation_id: str) -> list[ToolInvokeMessage]: + """ + Transform tool message and handle file download + """ + result = [] + + for message in messages: + if message.type == ToolInvokeMessage.MessageType.TEXT: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.LINK: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.IMAGE: + # try to download image + try: + file = ToolFileManager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + 
file_url=message.message
+                    )
+
+                    url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
+
+                    result.append(ToolInvokeMessage(
+                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+                        message=url,
+                        save_as=message.save_as,
+                        meta=message.meta.copy() if message.meta is not None else {},
+                    ))
+                except Exception as e:
+                    logger.exception(e)
+                    result.append(ToolInvokeMessage(
+                        type=ToolInvokeMessage.MessageType.TEXT,
+                        message=f"Failed to download image: {message.message}; you can try downloading it yourself.",
+                        meta=message.meta.copy() if message.meta is not None else {},
+                        save_as=message.save_as,
+                    ))
+            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+                # get mime type and save blob to storage
+                mimetype = message.meta.get('mime_type', 'application/octet-stream')
+                # if message is str, encode it to bytes
+                if isinstance(message.message, str):
+                    message.message = message.message.encode('utf-8')
+
+                file = ToolFileManager.create_file_by_raw(
+                    user_id=user_id, tenant_id=tenant_id,
+                    conversation_id=conversation_id,
+                    file_binary=message.message,
+                    mimetype=mimetype
+                )
+
+                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
+
+                # check if file is image
+                if 'image' in mimetype:
+                    result.append(ToolInvokeMessage(
+                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+                        message=url,
+                        save_as=message.save_as,
+                        meta=message.meta.copy() if message.meta is not None else {},
+                    ))
+                else:
+                    result.append(ToolInvokeMessage(
+                        type=ToolInvokeMessage.MessageType.LINK,
+                        message=url,
+                        save_as=message.save_as,
+                        meta=message.meta.copy() if message.meta is not None else {},
+                    ))
+            else:
+                result.append(message)
+
+        return result
\ No newline at end of file
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
new file mode 100644
index 00000000000000..e782bd30044790
--- /dev/null
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -0,0 +1,23 @@
+from typing import Literal, Union
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+ToolParameterValue = Union[str, int, float, bool]
+
+class ToolEntity(BaseModel):
+    provider_id: str
+    provider_type: Literal['builtin', 'api']
+    provider_name: str # redundancy
+    tool_name: str
+    tool_label: str # redundancy
+    tool_parameters: dict[str, ToolParameterValue]
+
+
+class ToolNodeData(BaseNodeData, ToolEntity):
+    """
+    Tool Node Schema
+    """
+    tool_inputs: list[VariableSelector]
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index b805a53d2f3719..a0b0991eb6c0b5 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,139 @@
+from os import path
+from typing import cast
+
+from core.file.file_obj import FileTransferMethod
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.tool.entities import ToolNodeData
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 class ToolNode(BaseNode):
-    pass
+    """
+    Tool Node
+    """
+    _node_data_cls = 
ToolNodeData
+    _node_type = NodeType.TOOL
+
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+        """
+        Run the tool node
+        """
+
+        node_data = cast(ToolNodeData, self.node_data)
+
+        # extract tool parameters
+        parameters = {
+            k.variable: variable_pool.get_variable_value(k.value_selector)
+            for k in node_data.tool_inputs
+        }
+
+        if len(parameters) != len(node_data.tool_inputs):
+            raise ValueError('Invalid tool parameters')
+
+        # get tool runtime
+        try:
+            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
+        except Exception as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=parameters,
+                error=f'Failed to get tool runtime: {str(e)}'
+            )
+
+        try:
+            messages = tool_runtime.invoke(None, parameters)
+        except Exception as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=parameters,
+                error=f'Failed to invoke tool: {str(e)}'
+            )
+
+        # convert tool messages
+        plain_text, files = self._convert_tool_messages(messages)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCESS,
+            outputs={
+                'text': plain_text,
+                'files': files
+            },
+        )
+
+    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert ToolInvokeMessages into tuple[plain_text, files]
+        """
+        # transform message and handle file storage
+        messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages)
+        # extract plain text and files
+        files = self._extract_tool_response_binary(messages)
+        plain_text = self._extract_tool_response_text(messages)
+
+        return plain_text, files
+
+    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
+        """
+        Extract tool response binary
+        """
+        result = []
+
+        for response in tool_response:
+            if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
+               response.type == ToolInvokeMessage.MessageType.IMAGE:
+                url = response.message
+                ext = path.splitext(url)[1]
+                mimetype = response.meta.get('mime_type', 'image/jpeg')
+                filename = response.save_as or url.split('/')[-1]
+                result.append({
+                    'type': 'image',
+                    'transfer_method': FileTransferMethod.TOOL_FILE,
+                    'url': url,
+                    'upload_file_id': None,
+                    'filename': filename,
+                    'file-ext': ext,
+                    'mime-type': mimetype,
+                })
+            elif response.type == ToolInvokeMessage.MessageType.BLOB:
+                result.append({
+                    'type': 'image',  # TODO: only images are supported for now
+                    'transfer_method': FileTransferMethod.TOOL_FILE,
+                    'url': response.message,
+                    'upload_file_id': None,
+                    'filename': response.save_as,
+                    'file-ext': path.splitext(response.save_as)[1],
+                    'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
+                })
+            elif response.type == ToolInvokeMessage.MessageType.LINK:
+                pass # TODO:
+
+        return result
+
+    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
+        """
+        Extract tool response text
+        """
+        return ''.join([
+            f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
+            f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
+            for message in tool_response
+        ])
+
+    def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
+        """
+        Convert ToolInvokeMessage into file
+        """
+        pass
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        """
+        pass
\ No newline at end of file
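
A note on the pieces introduced above: ToolFileMessageTransformer is now the single place where tool output is normalized before it reaches an agent run or a workflow node. A minimal sketch of how a binary tool result flows through it (illustrative only: the IDs are placeholders, and create_file_by_raw persists a real ToolFile record, so this assumes a working database session behind ToolFileManager):

    from core.tools.entities.tool_entities import ToolInvokeMessage
    from core.tools.utils.message_transformer import ToolFileMessageTransformer

    blob = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.BLOB,
        message=b'...raw png bytes...',    # binary payload produced by a tool
        save_as='chart.png',
        meta={'mime_type': 'image/png'},
    )

    transformed = ToolFileMessageTransformer.transform_tool_invoke_messages(
        messages=[blob],
        user_id='user-id',                 # placeholder
        tenant_id='tenant-id',             # placeholder
        conversation_id='conversation-id'  # placeholder
    )

    # Because the mime type is an image, the blob is stored via ToolFileManager
    # and comes back as a single IMAGE_LINK message whose URL looks like
    # /files/tools/<file_id>.png; a non-image blob would come back as a plain LINK.

From 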
5eb7b4d56a93acbd9cbfbd52af44d84e0ab3d76a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:13:52 +0800 Subject: [PATCH 113/160] feat: tool entity --- api/core/tools/tool_manager.py | 2 +- api/core/workflow/nodes/tool/entities.py | 19 +++++++++++---- api/core/workflow/nodes/tool/tool_node.py | 29 ++++++++++++----------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ea66362195b719..52e1e71d8240e1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -315,7 +315,7 @@ def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_c for parameter in parameters: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) runtime_parameters[parameter.name] = value # decrypt runtime parameters diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index e782bd30044790..0b3bf76aacfd2a 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,6 +1,6 @@ -from typing import Literal, Union +from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -13,11 +13,20 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_parameters: dict[str, ToolParameterValue] - + tool_configurations: dict[str, ToolParameterValue] class ToolNodeData(BaseNodeData, ToolEntity): + class ToolInput(VariableSelector): + variable_type: Literal['selector', 'static'] + value: Optional[str] + + @validator('value') + def check_value(cls, value, values, **kwargs): + if values['variable_type'] == 'static' and value is None: + raise ValueError('value is required for static variable') + return value + """ Tool Node Schema """ - tool_inputs: list[VariableSelector] + tool_parameters: list[ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a0b0991eb6c0b5..f1897780f2c777 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -27,14 +27,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data = cast(ToolNodeData, self.node_data) - # extract tool parameters - parameters = { - k.variable: variable_pool.get_variable_value(k.value_selector) - for k in node_data.tool_inputs - } - - if len(parameters) != len(node_data.tool_inputs): - raise ValueError('Invalid tool parameters') + # get parameters + parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: @@ -47,6 +41,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) try: + # TODO: user_id messages = tool_runtime.invoke(None, parameters) except Exception as e: return NodeRunResult( @@ -59,12 +54,23 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: plain_text, files = self._convert_tool_messages(messages) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCESS, + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ 'text': plain_text, 'files': files }, ) + + def _generate_parameters(self, variable_pool: VariablePool, node_data: 
ToolNodeData) -> dict: + """ + Generate parameters + """ + return { + k.variable: + k.value if k.variable_type == 'static' else + variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else '' + for k in node_data.tool_parameters + } def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: """ @@ -125,11 +131,6 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> for message in tool_response ]) - def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict: - """ - Convert ToolInvokeMessage into file - """ - pass @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: From 7a6fa3655f648935a5e0b82d4458a1263a98734f Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Mar 2024 16:31:43 +0800 Subject: [PATCH 114/160] add user for node --- api/core/app/apps/advanced_chat/app_runner.py | 6 +++++ api/core/app/apps/workflow/app_runner.py | 6 +++++ .../workflow/entities/workflow_entities.py | 12 +++++++-- api/core/workflow/nodes/base_node.py | 27 +++++++++++++++++++ api/core/workflow/workflow_engine_manager.py | 14 ++++++++-- .../unit_tests/core/workflow/__init__.py | 0 6 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/__init__.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c42620b92f9266..5f5fd7010c1c42 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -8,10 +8,12 @@ from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, Conversation, Message @@ -78,6 +80,10 @@ def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2d032fcdcb0922..922c3003bfba41 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -7,12 +7,14 @@ from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, + InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import 
db from models.model import App @@ -63,6 +65,10 @@ def run(self, application_generate_entity: WorkflowAppGenerateEntity, workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 91f9ef95fe51e6..a78bf09a531c14 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -2,7 +2,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from models.workflow import Workflow, WorkflowType @@ -20,6 +20,8 @@ class WorkflowRunState: app_id: str workflow_id: str workflow_type: WorkflowType + user_id: str + user_from: UserFrom start_at: float variable_pool: VariablePool @@ -28,11 +30,17 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id self.workflow_type = WorkflowType.value_of(workflow.type) + self.user_id = user_id + self.user_from = user_from self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6db25bea7eaa0d..a603f484ef4d84 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Optional from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback @@ -8,6 +9,26 @@ from models.workflow import WorkflowNodeExecutionStatus +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") + + class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -15,6 +36,8 @@ class BaseNode(ABC): tenant_id: str app_id: str workflow_id: str + user_id: str + user_from: UserFrom node_id: str node_data: BaseNodeData @@ -25,11 +48,15 @@ class BaseNode(ABC): def __init__(self, tenant_id: str, app_id: str, workflow_id: str, + user_id: str, + user_from: UserFrom, config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: self.tenant_id = tenant_id self.app_id = app_id self.workflow_id = workflow_id + self.user_id = user_id + self.user_from = user_from self.node_id = config.get("id") if not self.node_id: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index d01746ceb8b691..0bc13cbb5a7b98 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,7 +6,7 @@ from 
core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode @@ -76,12 +76,16 @@ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None return default_config def run_workflow(self, workflow: Workflow, + user_id: str, + user_from: UserFrom, user_inputs: dict, system_inputs: Optional[dict] = None, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param workflow: Workflow instance + :param user_id: user id + :param user_from: user from :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks @@ -113,7 +117,9 @@ def run_workflow(self, workflow: Workflow, variable_pool=VariablePool( system_variables=system_inputs, user_inputs=user_inputs - ) + ), + user_id=user_id, + user_from=user_from ) try: @@ -222,6 +228,8 @@ def _get_next_node(self, workflow_run_state: WorkflowRunState, tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=node_config, callbacks=callbacks ) @@ -267,6 +275,8 @@ def _get_next_node(self, workflow_run_state: WorkflowRunState, tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=target_node_config, callbacks=callbacks ) diff --git a/api/tests/unit_tests/core/workflow/__init__.py b/api/tests/unit_tests/core/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 From f911b1c488ccc18eaf274a6fa4c4869f57b6cf21 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:44:22 +0800 Subject: [PATCH 115/160] feat: support empty code output children --- api/core/workflow/nodes/code/code_node.py | 53 ++++- api/core/workflow/nodes/code/entities.py | 4 +- .../workflow/nodes/test_code.py | 206 ++++++++++-------- 3 files changed, 167 insertions(+), 96 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 8034f4e55da709..bfdec73199b886 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -153,11 +153,13 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f raise ValueError(f'{variable} in input form is out of range.') if isinstance(value, float): - value = round(value, MAX_PRECISION) + # raise error if precision is too high + if len(str(value).split('.')[1]) > MAX_PRECISION: + raise ValueError(f'{variable} in output form has too high precision.') return value - def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], + def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = '', depth: int = 1) -> dict: """ @@ -170,6 +172,47 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. 
raise ValueError("Depth limit reached, object too deep.") transformed_result = {} + if output_schema is None: + # validate output thought instance type + for output_name, output_value in result.items(): + if isinstance(output_value, dict): + self._transform_result( + result=output_value, + output_schema=None, + prefix=f'{prefix}.{output_name}' if prefix else output_name, + depth=depth + 1 + ) + elif isinstance(output_value, (int, float)): + self._check_number( + value=output_value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif isinstance(output_value, str): + self._check_string( + value=output_value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif isinstance(output_value, list): + if all(isinstance(value, (int, float)) for value in output_value): + for value in output_value: + self._check_number( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif all(isinstance(value, str) for value in output_value): + for value in output_value: + self._check_string( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + else: + raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + else: + raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') + + return result + + parameters_validated = {} for output_name, output_config in output_schema.items(): if output_config.type == 'object': # check if output is object @@ -236,6 +279,12 @@ def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData. ] else: raise ValueError(f'Output type {output_config.type} is not supported.') + + parameters_validated[output_name] = True + + # check if all output parameters are validated + if len(parameters_validated) != len(result): + raise ValueError('Not all output parameters are validated.') return transformed_result diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 6a18d181cb2ecc..ec3e3fe530d286 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal, Optional from pydantic import BaseModel @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Union[None, dict[str, 'Output']] + children: Optional[dict[str, 'Output']] variables: list[VariableSelector] answer: str diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 2885b9f4585996..0b7217b053b067 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,8 +1,9 @@ import pytest +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.code.code_node import CodeNode -from models.workflow import WorkflowNodeExecutionStatus, WorkflowRunStatus +from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) @@ -15,30 +16,37 @@ def main(args1: int, args2: int) -> dict: ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in 
code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'outputs': { + 'result': { + 'type': 'number', + }, }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}) @@ -61,30 +69,37 @@ def main(args1: int, args2: int) -> dict: ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - "outputs": { - "result": { - "type": "string", - }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + "outputs": { + "result": { + "type": "string", + }, }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}) @@ -108,60 +123,67 @@ def main(args1: int, args2: int) -> dict: ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - "outputs": { - "string_validator": { - "type": "string", - }, - "number_validator": { - "type": "number", - }, - "number_array_validator": { - "type": "array[number]", - }, - "string_array_validator": { - "type": "array[string]", - }, - "object_validator": { - "type": "object", - "children": { - "result": { - "type": "number", - }, - "depth": { - "type": "object", - "children": { - "depth": { - "type": "object", - "children": { - "depth": { - "type": "number", + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } } } } } } - } + }, }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] 
- } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct result result = { From 91845fc9f6e652b1f6dd327abfc3870df373c295 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:44:36 +0800 Subject: [PATCH 116/160] fix: linter --- api/core/workflow/nodes/code/code_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index bfdec73199b886..2f22a386e55880 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -182,7 +182,7 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code prefix=f'{prefix}.{output_name}' if prefix else output_name, depth=depth + 1 ) - elif isinstance(output_value, (int, float)): + elif isinstance(output_value, int | float): self._check_number( value=output_value, variable=f'{prefix}.{output_name}' if prefix else output_name @@ -193,7 +193,7 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code variable=f'{prefix}.{output_name}' if prefix else output_name ) elif isinstance(output_value, list): - if all(isinstance(value, (int, float)) for value in output_value): + if all(isinstance(value, int | float) for value in output_value): for value in output_value: self._check_number( value=value, From 407bfb8182ee32c2057ae2081c2d8dbc895d5c01 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:46:11 +0800 Subject: [PATCH 117/160] feat: add user uid --- api/core/workflow/nodes/tool/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index f1897780f2c777..b0bc1246bd6949 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -42,7 +42,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: try: # TODO: user_id - messages = tool_runtime.invoke(None, parameters) + messages = tool_runtime.invoke(self.user_id, parameters) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, From f318fa058ccd95cb996c64663dbfcf4a1271e220 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:48:28 +0800 Subject: [PATCH 118/160] feat: add variable selector mapping --- api/core/workflow/nodes/tool/tool_node.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b0bc1246bd6949..bfa7db3943eb3b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -137,4 +137,8 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) """ Extract variable selector to variable mapping """ - pass \ No newline at end of file + return { + k.value_selector: k.variable + for k in cast(ToolNodeData, node_data).tool_parameters + if k.variable_type == 'selector' + } \ No newline at end of file From 88c29f613f8d01be1fcb01a0a1ba8bfee78cb6f7 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:51:27 +0800 Subject: [PATCH 119/160] fix: typing --- api/core/workflow/nodes/code/entities.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index ec3e3fe530d286..0e2b3c99bfb4a8 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Optional[dict[str, 'Output']] + children: Optional[dict[str, 'CodeNodeData.Output']] variables: list[VariableSelector] answer: str From 33113034ea6ad02a8b59f5efe7645824ad6bedc3 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Mar 2024 18:49:22 +0800 Subject: [PATCH 120/160] add single step run --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/workflow.py | 21 +++-- api/core/workflow/errors.py | 10 +++ api/core/workflow/nodes/base_node.py | 4 +- api/core/workflow/nodes/code/code_node.py | 6 +- .../nodes/direct_answer/direct_answer_node.py | 10 ++- api/core/workflow/nodes/end/end_node.py | 2 +- .../nodes/http_request/http_request_node.py | 6 +- api/core/workflow/nodes/llm/llm_node.py | 2 +- api/core/workflow/nodes/start/start_node.py | 2 +- .../template_transform_node.py | 4 +- api/core/workflow/nodes/tool/tool_node.py | 6 +- api/core/workflow/workflow_engine_manager.py | 88 +++++++++++++++++++ api/fields/workflow_run_fields.py | 8 +- api/services/workflow_run_service.py | 14 +-- api/services/workflow_service.py | 86 +++++++++++++++++- 16 files changed, 232 insertions(+), 39 deletions(-) create mode 100644 api/core/workflow/errors.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a6f803785ab2f4..853ca9e3a7ca4d 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ from . 
import admin, apikey, extension, feature, setup, version, ping
 # Import app controllers
 from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
-                  model_config, site, statistic, workflow, workflow_app_log)
+                  model_config, site, statistic, workflow, workflow_run, workflow_app_log)
 # Import auth controllers
 from .auth import activate, data_source_oauth, login, oauth
 # Import billing controllers
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 5f03a7cd377744..6f81da569126f3 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -15,6 +15,7 @@
 from controllers.console.wraps import account_initialization_required
 from core.app.entities.app_invoke_entities import InvokeFrom
 from fields.workflow_fields import workflow_fields
+from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.model import App, AppMode
@@ -164,18 +165,24 @@ class DraftWorkflowNodeRunApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @marshal_with(workflow_run_node_execution_fields)
     def post(self, app_model: App, node_id: str):
         """
         Run draft workflow node
         """
-        # TODO
+        parser = reqparse.RequestParser()
+        parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
         workflow_service = WorkflowService()
-        workflow_service.run_draft_workflow_node(app_model=app_model, node_id=node_id, account=current_user)
+        workflow_node_execution = workflow_service.run_draft_workflow_node(
+            app_model=app_model,
+            node_id=node_id,
+            user_inputs=args.get('inputs'),
+            account=current_user
+        )
 
-        # TODO
-        return {
-            "result": "success"
-        }
+        return workflow_node_execution
 
 
 class PublishedWorkflowApi(Resource):
@@ -291,7 +298,7 @@ def generate() -> Generator:
 api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
 api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
 api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop')
-api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<uuid:node_id>/run')
+api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
 api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published')
 api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
 api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>')
diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py
new file mode 100644
index 00000000000000..fe79fadf66b876
--- /dev/null
+++ b/api/core/workflow/errors.py
@@ -0,0 +1,10 @@
+from core.workflow.entities.node_entities import NodeType
+
+
+class WorkflowNodeRunFailedError(Exception):
+    def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
+        self.node_id = node_id
+        self.node_type = node_type
+        self.node_title = node_title
+        self.error = error
+        super().__init__(f"Node {node_title} run failed: {error}")
diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py
index a603f484ef4d84..dfba9d0385ee0b 100644
--- a/api/core/workflow/nodes/base_node.py
+++ b/api/core/workflow/nodes/base_node.py
@@ -108,7 +108,7 @@ def publish_text_chunk(self, text: str) -> None:
         )
 
     @classmethod
-    
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param config: node config @@ -119,7 +119,7 @@ def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: @classmethod @abstractmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 2f22a386e55880..2c11e5ba00b9ac 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -289,7 +289,7 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data @@ -297,5 +297,5 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables - } \ No newline at end of file + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 9193bab9ee5463..fedbc9b2d1c86a 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -50,10 +50,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - return {} + node_data = cast(cls._node_data_cls, node_data) + + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + return variable_mapping diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 65b0b86aa0314f..2666ccc4f97a04 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -56,7 +56,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 4ee76deb83eaa8..853f8fe5e37f2c 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ 
b/api/core/workflow/nodes/http_request/http_request_node.py @@ -48,12 +48,12 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables - } \ No newline at end of file + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 90a7755b85de81..41e28937ac7d2a 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -23,7 +23,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: pass @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2321e04bd4256e..08171457fbb4dc 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -69,7 +69,7 @@ def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict return filtered_inputs @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index a037332f4bcaf6..c41f5d1030a6bc 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -72,12 +72,12 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } \ No newline at end of file diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index bfa7db3943eb3b..69a97fc206172b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -133,12 +133,12 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def 
_extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping """ return { - k.value_selector: k.variable + k.variable: k.value_selector for k in cast(ToolNodeData, node_data).tool_parameters if k.variable_type == 'selector' - } \ No newline at end of file + } diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 0bc13cbb5a7b98..17225c19ea0bb9 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -180,6 +181,93 @@ def run_workflow(self, workflow: Workflow, callbacks=callbacks ) + def single_step_run_workflow_node(self, workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get('id') == node_id: + node_config = node + break + + if not node_config: + raise ValueError('node id not found in workflow graph') + + # Get node class + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + + # init workflow run state + node_instance = node_cls( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + config=node_config + ) + + try: + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={} + ) + + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + except NotImplementedError: + variable_mapping = {} + + for variable_key, variable_selector in variable_mapping.items(): + if variable_key not in user_inputs: + raise ValueError(f'Variable key {variable_key} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + + # append variable and value to variable pool + variable_pool.append_variable( + node_id=variable_node_id, + variable_key_list=variable_key_list, + value=user_inputs.get(variable_key) + ) + + # run node + node_run_result = node_instance.run( + variable_pool=variable_pool + ) + except Exception as e: + raise WorkflowNodeRunFailedError( + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_title=node_instance.node_data.title, + error=str(e) + ) + + return node_instance, node_run_result + + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] 
= None) -> None: """ Workflow run success diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 572f472f1f083f..3135d91fd3d3d6 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -34,11 +34,9 @@ } workflow_run_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='items') + 'limit': fields.Integer(attribute='limit'), + 'has_more': fields.Boolean(attribute='has_more'), + 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') } workflow_run_detail_fields = { diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 70ce1f2ce0406e..1d3f93f2247d19 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -34,26 +34,26 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro if not last_workflow_run: raise ValueError('Last workflow run not exists') - conversations = base_query.filter( + workflow_runs = base_query.filter( WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() else: - conversations = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() has_more = False - if len(conversations) == limit: - current_page_first_conversation = conversations[-1] + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_conversation.created_at, - WorkflowRun.id != current_page_first_conversation.id + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id ).count() if rest_count > 0: has_more = True return InfiniteScrollPagination( - data=conversations, + data=workflow_runs, limit=limit, has_more=has_more ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f8bd80a0b1c125..2c9c07106cec94 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,4 +1,5 @@ import json +import time from collections.abc import Generator from datetime import datetime from typing import Optional, Union @@ -9,12 +10,21 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, EndUser -from models.workflow import Workflow, WorkflowType +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowType, +) from services.workflow.workflow_converter import WorkflowConverter @@ -214,6 +224,80 @@ def stop_workflow_task(self, task_id: str, """ AppQueueManager.set_stop_flag(task_id, 
invoke_from, user.id)
 
+    def run_draft_workflow_node(self, app_model: App,
+                                node_id: str,
+                                user_inputs: dict,
+                                account: Account) -> WorkflowNodeExecution:
+        """
+        Run draft workflow node
+        """
+        # fetch draft workflow by app_model
+        draft_workflow = self.get_draft_workflow(app_model=app_model)
+        if not draft_workflow:
+            raise ValueError('Workflow not initialized')
+
+        # run draft workflow node
+        workflow_engine_manager = WorkflowEngineManager()
+        start_at = time.perf_counter()
+
+        try:
+            node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
+                workflow=draft_workflow,
+                node_id=node_id,
+                user_inputs=user_inputs,
+                user_id=account.id,
+            )
+        except WorkflowNodeRunFailedError as e:
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=e.node_id,
+                node_type=e.node_type.value,
+                title=e.node_title,
+                status=WorkflowNodeExecutionStatus.FAILED.value,
+                error=e.error,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )
+            db.session.add(workflow_node_execution)
+            db.session.commit()
+
+            return workflow_node_execution
+
+        # create workflow node execution
+        workflow_node_execution = WorkflowNodeExecution(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            workflow_id=draft_workflow.id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+            index=1,
+            node_id=node_id,
+            node_type=node_instance.node_type.value,
+            title=node_instance.node_data.title,
+            inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
+            process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
+            outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
+            execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
+                                if node_run_result.metadata else None),
+            status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
+            elapsed_time=time.perf_counter() - start_at,
+            created_by_role=CreatedByRole.ACCOUNT.value,
+            created_by=account.id,
+            created_at=datetime.utcnow(),
+            finished_at=datetime.utcnow()
+        )
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
+
     def convert_to_workflow(self, app_model: App, account: Account) -> App:
         """
         Basic mode of chatbot app (expert mode) to workflow

From f2bb0012fdc980c989c0805a27deedc35ad06388 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 11 Mar 2024 18:52:24 +0800
Subject: [PATCH 121/160] add debug code

---
 api/core/workflow/nodes/direct_answer/direct_answer_node.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py
index fedbc9b2d1c86a..22ef2ed53b7e61 100644
--- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py
+++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py
@@ -39,7 +39,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         # publish answer as stream
         for word in answer:
             self.publish_text_chunk(word)
-            time.sleep(0.01)
+            time.sleep(10)  # TODO for debug
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
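
With the single-step API from PATCH 120 and the deliberately slowed direct-answer node above, a draft node can now be exercised in isolation. A rough sketch of driving it from a console-side script (illustrative only: app_model and account stand for a real App row and a logged-in console Account, and the node id must exist in the app's draft graph):

    from services.workflow_service import WorkflowService

    node_execution = WorkflowService().run_draft_workflow_node(
        app_model=app_model,                    # an App owned by the current tenant (placeholder)
        node_id='1711528915321',                # a node id taken from the draft graph (placeholder)
        user_inputs={'args1': 1, 'args2': 2},
        account=account,                        # the console user (placeholder)
    )
    print(node_execution.status, node_execution.outputs)

The same run is reachable over HTTP through the console route registered above, by POSTing {"inputs": {...}} to /console/api/apps/<app_id>/workflows/draft/nodes/<node_id>/run.

From 7f7269d261349027dd93661b0d82c6f71ab5bef7 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 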
11 Mar 2024 19:04:48 +0800 Subject: [PATCH 122/160] remove unused params in workflow_run_for_list_fields --- api/fields/workflow_run_fields.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3135d91fd3d3d6..72510cd27ac621 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -20,11 +20,7 @@ "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), - "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, From 7372776992ac2fd2e5976be1a5396ad6503e06ea Mon Sep 17 00:00:00 2001 From: jyong Date: Mon, 11 Mar 2024 20:06:38 +0800 Subject: [PATCH 123/160] knowledge node --- .../knowledge_retrieval/knowledge_retrieval_node.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index c6dd6249216faa..7b8344418b2fa0 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,5 +1,13 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode class KnowledgeRetrievalNode(BaseNode): - pass + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + pass From ebf9c41adb68008d88f61b896bdebdf84ae337f4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 18:02:20 +0800 Subject: [PATCH 124/160] feat: http --- api/core/helper/ssrf_proxy.py | 1 + .../workflow/nodes/http_request/entities.py | 4 +- .../nodes/http_request/http_executor.py | 88 +++++++++++-------- .../nodes/http_request/http_request_node.py | 4 +- .../workflow/nodes/__mock/http.py | 82 +++++++++++++++++ .../workflow/nodes/test_http.py | 51 +++++++++++ 6 files changed, 191 insertions(+), 39 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/__mock/http.py create mode 100644 api/tests/integration_tests/workflow/nodes/test_http.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index c44d4717e6da33..22f5fe57e0fc24 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -26,6 +26,7 @@ } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None def get(url, *args, **kwargs): + print(url, kwargs) return _get(url=url, *args, proxies=httpx_proxies, **kwargs) def post(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 1e906cbaa4b777..ce806b6bdbad85 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal, Optional, Union from pydantic import BaseModel @@ -29,4 +29,4 @@ class Body(BaseModel): authorization: Authorization headers: str params: str - body: Body \ No newline at end of 
file + body: Optional[Body] \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 82d879a89ccb4e..6134a7d780b2ce 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -76,11 +76,17 @@ def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, An # fill in params kv_paris = original_params.split('\n') for kv in kv_paris: + if not kv.strip(): + continue + kv = kv.split(':') - if len(kv) != 2: + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: raise ValueError(f'Invalid params {kv}') - k, v = kv self.params[k] = v # extract all template in headers @@ -96,51 +102,61 @@ def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, An # fill in headers kv_paris = original_headers.split('\n') for kv in kv_paris: + if not kv.strip(): + continue + kv = kv.split(':') - if len(kv) != 2: + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: raise ValueError(f'Invalid headers {kv}') - k, v = kv self.headers[k] = v # extract all template in body - body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] - body_template = list(set(body_template)) - original_body = node_data.body.data or '' - for body in body_template: - if not body: - continue + if node_data.body: + body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] + body_template = list(set(body_template)) + original_body = node_data.body.data or '' + for body in body_template: + if not body: + continue + + original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) + + if node_data.body.type == 'json': + self.headers['Content-Type'] = 'application/json' + elif node_data.body.type == 'x-www-form-urlencoded': + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + # elif node_data.body.type == 'form-data': + # self.headers['Content-Type'] = 'multipart/form-data' - original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) - - if node_data.body.type == 'json': - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded': - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' - # elif node_data.body.type == 'form-data': - # self.headers['Content-Type'] = 'multipart/form-data' - - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: - body = {} - kv_paris = original_body.split('\n') - for kv in kv_paris: - kv = kv.split(':') - if len(kv) != 2: - raise ValueError(f'Invalid body {kv}') - body[kv[0]] = kv[1] - - if node_data.body.type == 'form-data': - self.files = { - k: ('', v) for k, v in body.items() - } + if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + body = {} + kv_paris = original_body.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) == 2: + body[kv[0]] = kv[1] + elif len(kv) == 1: + body[kv[0]] = '' + else: + raise ValueError(f'Invalid body {kv}') + + if node_data.body.type == 'form-data': + self.files = { + k: ('', v) for k, v in body.items() + } + else: + self.body = urlencode(body) else: - self.body = urlencode(body) - else: - self.body = original_body + self.body = original_body def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or [] + headers = deepcopy(self.headers) or {} if 
self.authorization.type == 'api-key': if self.authorization.config.api_key is None: raise ValueError('api_key is required') diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 853f8fe5e37f2c..1ef6f4b66d9ed5 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -24,10 +24,12 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # init http executor try: http_executor = HttpExecutor(node_data=node_data, variables=variables) - # invoke http executor + # invoke http executor response = http_executor.invoke() except Exception as e: + import traceback + print(traceback.format_exc()) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py new file mode 100644 index 00000000000000..3c2b0cebfc4c86 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -0,0 +1,82 @@ +import os +import pytest +import requests.api as requests +import httpx._api as httpx +from requests import Response as RequestsResponse +from yarl import URL + +from typing import Literal +from _pytest.monkeypatch import MonkeyPatch +from json import dumps + +MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' + +class MockedHttp: + def requests_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + url: str, **kwargs) -> RequestsResponse: + """ + Mocked requests.request + """ + response = RequestsResponse() + response.url = str(URL(url) % kwargs.get('params', {})) + response.headers = kwargs.get('headers', {}) + + if url == 'http://404.com': + response.status_code = 404 + response._content = b'Not Found' + return response + + # get data, files + data = kwargs.get('data', None) + files = kwargs.get('files', None) + + if data is not None: + resp = dumps(data).encode('utf-8') + if files is not None: + resp = dumps(files).encode('utf-8') + else: + resp = b'OK' + + response.status_code = 200 + response._content = resp + return response + + def httpx_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + url: str, **kwargs) -> httpx.Response: + """ + Mocked httpx.request + """ + response = httpx.Response() + response.url = str(URL(url) % kwargs.get('params', {})) + response.headers = kwargs.get('headers', {}) + + if url == 'http://404.com': + response.status_code = 404 + response.content = b'Not Found' + return response + + # get data, files + data = kwargs.get('data', None) + files = kwargs.get('files', None) + + if data is not None: + resp = dumps(data).encode('utf-8') + if files is not None: + resp = dumps(files).encode('utf-8') + else: + resp = b'OK' + + response.status_code = 200 + response.content = resp + return response + +@pytest.fixture +def setup_http_mock(request, monkeypatch: MonkeyPatch): + if not MOCK: + yield + return + + monkeypatch.setattr(requests, "request", MockedHttp.requests_request) + monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) + yield + monkeypatch.undo() \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py new file mode 100644 index 00000000000000..25c293d5635e52 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -0,0 +1,51 @@ +from calendar import c 
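# --- Editor's note (not part of the original patch series) -----------------
# Patch 124 above loosens the "key:value" line parsing that
# HttpExecutor._init_template applies to params, headers and form bodies:
# blank lines are now skipped, and a bare key is accepted with an empty
# value. A minimal, self-contained sketch of that behaviour follows; the
# helper name `parse_kv_lines` is hypothetical and does not exist in the
# patch.
def parse_kv_lines(text: str) -> dict[str, str]:
    result: dict[str, str] = {}
    for line in text.split('\n'):
        if not line.strip():
            continue  # skip blank lines, as the patched executor now does
        parts = line.split(':')
        if len(parts) == 2:
            key, value = parts
        elif len(parts) == 1:
            key, value = parts[0], ''  # bare key maps to an empty value
        else:
            raise ValueError(f'Invalid line {line!r}')
        result[key] = value
    return result

# Usage: parse_kv_lines('X-Header:123\n\nA:b') == {'X-Header': '123', 'A': 'b'}
# Note that splitting on every ':' still rejects values that themselves
# contain a colon (for example URLs); the sketch mirrors the patch rather
# than fixing that.
# ----------------------------------------------------------------------------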
+import pytest +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode + +from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock + +BASIC_NODE_DATA = { + 'tenant_id': '1', + 'app_id': '1', + 'workflow_id': '1', + 'user_id': '1', + 'user_from': InvokeFrom.WEB_APP, +} + +# construct variable pool +pool = VariablePool(system_variables={}, user_inputs={}) +pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) +pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_get_param(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': '', + 'params': '', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + print(result) + + assert 1==2 \ No newline at end of file From d3385a2715d8eeeee9d705cc0438283993d07aaa Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 19:51:31 +0800 Subject: [PATCH 125/160] feat --- api/core/helper/ssrf_proxy.py | 1 - .../nodes/http_request/http_executor.py | 19 +- .../nodes/http_request/http_request_node.py | 10 +- .../workflow/nodes/__mock/http.py | 15 +- .../workflow/nodes/test_http.py | 172 +++++++++++++++++- 5 files changed, 197 insertions(+), 20 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 22f5fe57e0fc24..c44d4717e6da33 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -26,7 +26,6 @@ } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None def get(url, *args, **kwargs): - print(url, kwargs) return _get(url=url, *args, proxies=httpx_proxies, **kwargs) def post(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 6134a7d780b2ce..c96d5f07d19a82 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -43,6 +43,7 @@ def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): self.params = {} self.headers = {} self.body = None + self.files = None # init template self._init_template(node_data, variables) @@ -248,10 +249,24 @@ def to_raw_request(self) -> str: server_url += f'?{urlencode(self.params)}' raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' - for k, v in self.headers.items(): + + headers = self._assembling_headers() + for k, v in headers.items(): raw_request += f'{k}: {v}\n' raw_request += '\n' - raw_request += self.body or '' + + # if files, use multipart/form-data with boundary + if self.files: + boundary = '----WebKitFormBoundary7MA4YWxkTrZu0gW' + raw_request = f'--{boundary}\n' + raw_request + for k, v in self.files.items(): + raw_request += f'Content-Disposition: form-data; name="{k}"; filename="{v[0]}"\n' + raw_request += f'Content-Type: {v[1]}\n\n' + raw_request += v[1] + '\n' + raw_request += f'--{boundary}\n' + raw_request += '--\n' + else: + raw_request += self.body or '' return raw_request \ No newline at 
end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 1ef6f4b66d9ed5..c83e331fa8f4c3 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -28,13 +28,13 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # invoke http executor response = http_executor.invoke() except Exception as e: - import traceback - print(traceback.format_exc()) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), - process_data=http_executor.to_raw_request() + process_data={ + 'request': http_executor.to_raw_request() + } ) return NodeRunResult( @@ -45,7 +45,9 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: 'body': response, 'headers': response.headers }, - process_data=http_executor.to_raw_request() + process_data={ + 'request': http_executor.to_raw_request(), + } ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index 3c2b0cebfc4c86..9cc43031f3a05c 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -3,6 +3,7 @@ import requests.api as requests import httpx._api as httpx from requests import Response as RequestsResponse +from httpx import Request as HttpxRequest from yarl import URL from typing import Literal @@ -12,8 +13,8 @@ MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' class MockedHttp: - def requests_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], - url: str, **kwargs) -> RequestsResponse: + def requests_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str, + **kwargs) -> RequestsResponse: """ Mocked requests.request """ @@ -41,13 +42,15 @@ def requests_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATC response._content = resp return response - def httpx_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str, **kwargs) -> httpx.Response: """ Mocked httpx.request """ - response = httpx.Response() - response.url = str(URL(url) % kwargs.get('params', {})) + response = httpx.Response( + status_code=200, + request=HttpxRequest(method, url) + ) response.headers = kwargs.get('headers', {}) if url == 'http://404.com': @@ -67,7 +70,7 @@ def httpx_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', resp = b'OK' response.status_code = 200 - response.content = resp + response._content = resp return response @pytest.fixture diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 25c293d5635e52..6df8f6b6733de2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -2,7 +2,6 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import HttpRequestNodeData from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -21,13 +20,16 @@ pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], 
value=2) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_get_param(setup_http_mock): +def test_get(setup_http_mock): node = HttpRequestNode(config={ 'id': '1', 'data': { 'title': 'http', 'desc': '', - 'variables': [], + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], 'method': 'get', 'url': 'http://example.com', 'authorization': { @@ -38,14 +40,170 @@ def test_get_param(setup_http_mock): 'header': 'api-key', } }, - 'headers': '', - 'params': '', + 'headers': 'X-Header:123', + 'params': 'A:b', 'body': None, } }, **BASIC_NODE_DATA) result = node.run(pool) - print(result) + data = result.process_data.get('request', '') - assert 1==2 \ No newline at end of file + assert '?A=b' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_template(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'get', + 'url': 'http://example.com/{{args1}}', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123\nX-Header2:{{args1}}', + 'params': 'A:b\nTemplate:{{args1}}', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'Template=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + assert 'X-Header2: 2' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_json(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'json', + 'data': '{"a": "{{args1}}"}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert '{"a": "1"}' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_x_www_form_urlencoded(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'x-www-form-urlencoded', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'a=1&b=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_form_data(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 
'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'form-data', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'form-data; name="a"' in data + assert '1' in data + assert 'form-data; name="b"' in data + assert '2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data From 513a8655b1009eec73f07c3f9390ab8ef2b60da7 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:31:39 +0800 Subject: [PATCH 126/160] test: tool --- api/core/tools/tool_manager.py | 9 ++- api/core/workflow/nodes/tool/tool_node.py | 11 +-- .../workflow/nodes/test_tool.py | 70 +++++++++++++++++++ 3 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/test_tool.py diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 52e1e71d8240e1..600b54f1c20ad9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -315,8 +315,9 @@ def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_c for parameter in parameters: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) - runtime_parameters[parameter.name] = value + if parameter.form == ToolParameter.ToolParameterForm.FORM: + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -325,7 +326,9 @@ def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_c provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, ) - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + if runtime_parameters: + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 69a97fc206172b..c62e025e75fed1 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -29,7 +29,6 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: # get parameters parameters = self._generate_parameters(variable_pool, node_data) - # get tool runtime try: tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) @@ -41,7 +40,6 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) try: - # TODO: user_id messages = tool_runtime.invoke(self.user_id, parameters) except Exception as e: return NodeRunResult( @@ -68,7 +66,7 @@ def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeD return { k.variable: k.value if k.variable_type == 'static' else - variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else '' + variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else '' for k in node_data.tool_parameters } @@ -77,7 +75,12 @@ def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> 
tuple[str Convert ToolInvokeMessages into tuple[plain_text, files] """ # transform message and handle file storage - messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages) + messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id='', + ) # extract plain text and files files = self._extract_tool_response_binary(messages) plain_text = self._extract_tool_response_text(messages) diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py new file mode 100644 index 00000000000000..72e0d6f8536ae5 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -0,0 +1,70 @@ +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.tool.tool_node import ToolNode +from models.workflow import WorkflowNodeExecutionStatus + +""" +class ToolEntity(BaseModel): + provider_id: str + provider_type: Literal['builtin', 'api'] + provider_name: str # redundancy + tool_name: str + tool_label: str # redundancy + tool_configurations: dict[str, ToolParameterValue] + +class ToolNodeData(BaseNodeData, ToolEntity): + class ToolInput(VariableSelector): + variable_type: Literal['selector', 'static'] + value: Optional[str] + + @validator('value') + def check_value(cls, value, values, **kwargs): + if values['variable_type'] == 'static' and value is None: + raise ValueError('value is required for static variable') + return value + + tool_parameters: list[ToolInput] + +""" + +def test_tool_invoke(): + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') + + node = ToolNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'title': 'a', + 'desc': 'a', + 'provider_id': 'maths', + 'provider_type': 'builtin', + 'provider_name': 'maths', + 'tool_name': 'eval_expression', + 'tool_label': 'eval_expression', + 'tool_configurations': {}, + 'tool_parameters': [ + { + 'variable': 'expression', + 'value_selector': ['1', '123', 'args1'], + 'variable_type': 'selector', + 'value': None + }, + ] + } + } + ) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert '2' in result.outputs['text'] + assert result.outputs['files'] == [] \ No newline at end of file From 2c2b9e738929da9ab06689e37123c5d645b3be87 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:52:49 +0800 Subject: [PATCH 127/160] test: template transform --- .../template_transform_node.py | 9 +++- .../workflow/nodes/__mock/code_executor.py | 4 ++ .../workflow/nodes/test_template_transform.py | 46 +++++++++++++++++++ .../workflow/nodes/test_tool.py | 25 ---------- 4 files changed, 58 insertions(+), 26 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/test_template_transform.py diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index c41f5d1030a6bc..15d4b2a6e7b81f 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -7,6 +7,7 @@ from core.workflow.nodes.template_transform.entities 
import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = 1000 class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData @@ -48,7 +49,6 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: ) variables[variable] = value - # Run code try: result = CodeExecutor.execute_code( @@ -62,6 +62,13 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: status=WorkflowNodeExecutionStatus.FAILED, error=str(e) ) + + if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + return NodeRunResult( + inputs=variables, + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index a1c8eb71dc5f23..2eb987181fa93a 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -15,6 +15,10 @@ def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, return { "result": 3 } + elif language == 'jinja2': + return { + "result": "3" + } @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py new file mode 100644 index 00000000000000..4348995a055026 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -0,0 +1,46 @@ +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = '''{{args2}}''' + node = TemplateTransformNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'template': code, + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=3) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['output'] == '3' diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 72e0d6f8536ae5..66139563e29429 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,34 +1,9 @@ -import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from 
core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus -""" -class ToolEntity(BaseModel): - provider_id: str - provider_type: Literal['builtin', 'api'] - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, ToolParameterValue] - -class ToolNodeData(BaseNodeData, ToolEntity): - class ToolInput(VariableSelector): - variable_type: Literal['selector', 'static'] - value: Optional[str] - - @validator('value') - def check_value(cls, value, values, **kwargs): - if values['variable_type'] == 'static' and value is None: - raise ValueError('value is required for static variable') - return value - - tool_parameters: list[ToolInput] - -""" - def test_tool_invoke(): pool = VariablePool(system_variables={}, user_inputs={}) pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') From b102562614d6c57db7b5b07efa4c352822b862f5 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:58:54 +0800 Subject: [PATCH 128/160] fix: forward-ref --- api/core/workflow/nodes/code/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 0e2b3c99bfb4a8..ec3e3fe530d286 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Optional[dict[str, 'CodeNodeData.Output']] + children: Optional[dict[str, 'Output']] variables: list[VariableSelector] answer: str From a420953385f3ebd7fdd08996f5976395b5e8a99b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 22:12:13 +0800 Subject: [PATCH 129/160] feat: docker-compose --- docker/docker-compose.middleware.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index afdabd078a4bbe..60604aeaecec79 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -11,6 +11,9 @@ services: POSTGRES_DB: dify # postgres data directory PGDATA: /var/lib/postgresql/data/pgdata + # The sandbox service endpoint. + CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" + CODE_EXECUTION_API_KEY: dify-sandbox volumes: - ./volumes/db/data:/var/lib/postgresql/data ports: @@ -50,6 +53,16 @@ services: AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' ports: - "8080:8080" + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + ports: + - "8194:8194" # Qdrant vector store. # uncomment to use qdrant as vector store. From 951aaf5161d6c812745b6953ade0b22ff72cf630 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 22:14:28 +0800 Subject: [PATCH 130/160] feat: sandbox --- docker/docker-compose.middleware.yaml | 3 --- docker/docker-compose.yaml | 13 +++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 60604aeaecec79..8fba59c3154415 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -11,9 +11,6 @@ services: POSTGRES_DB: dify # postgres data directory PGDATA: /var/lib/postgresql/data/pgdata - # The sandbox service endpoint. 
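# --- Editor's note (not part of the original patch series) -----------------
# Patch 129 introduced the two CODE_EXECUTION_* variables under the Postgres
# (db) service of this middleware compose file; patch 130 removes them here
# and re-adds them under the `api` service in docker/docker-compose.yaml,
# which is the container that actually calls the sandbox. A sketch of the
# resulting wiring, using only values that appear in these patches:
#
#   api:
#     environment:
#       CODE_EXECUTION_ENDPOINT: "http://sandbox:8194"
#       CODE_EXECUTION_API_KEY: dify-sandbox
#   sandbox:
#     image: langgenius/dify-sandbox:latest
#     environment:
#       API_KEY: dify-sandbox
#     ports:
#       - "8194:8194"
# ----------------------------------------------------------------------------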
- CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" - CODE_EXECUTION_API_KEY: dify-sandbox volumes: - ./volumes/db/data:/var/lib/postgresql/data ports: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d627bb38481173..ca6b6cbf1a17d2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -122,6 +122,9 @@ services: SENTRY_TRACES_SAMPLE_RATE: 1.0 # The sample rate for Sentry profiles. Default: `1.0` SENTRY_PROFILES_SAMPLE_RATE: 1.0 + # The sandbox service endpoint. + CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" + CODE_EXECUTION_API_KEY: dify-sandbox depends_on: - db - redis @@ -286,6 +289,16 @@ services: # ports: # - "8080:8080" + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + ports: + - "8194:8194" + # Qdrant vector store. # uncomment to use qdrant as vector store. # (if uncommented, you need to comment out the weaviate service above, From 92c1da8dbeb92310bb07c7507aee2420c4cd179e Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 12 Mar 2024 16:25:07 +0800 Subject: [PATCH 131/160] fix: remove answer --- api/core/workflow/nodes/code/entities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index ec3e3fe530d286..d4d76c45f9f879 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -15,7 +15,6 @@ class Output(BaseModel): children: Optional[dict[str, 'Output']] variables: list[VariableSelector] - answer: str code_language: Literal['python3', 'javascript'] code: str outputs: dict[str, Output] From e8751bebfa1b8b05ae6cf1274a4457075f51de07 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 19:15:11 +0800 Subject: [PATCH 132/160] fix single step run error --- api/services/workflow_service.py | 64 +++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2c9c07106cec94..55f2526fbfc827 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -270,28 +270,48 @@ def run_draft_workflow_node(self, app_model: App, return workflow_node_execution - # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() - ) + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + 
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, + process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, + outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, + execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) + if node_run_result.metadata else None), + status=WorkflowNodeExecutionStatus.SUCCEEDED.value, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + else: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + status=node_run_result.status.value, + error=node_run_result.error, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) db.session.add(workflow_node_execution) db.session.commit() From d88ac6c238412984e37967e51219e553f12bc254 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 22:12:03 +0800 Subject: [PATCH 133/160] add llm node --- api/core/app/apps/base_app_runner.py | 31 +- .../easy_ui_based_generate_task_pipeline.py | 83 +--- api/core/model_manager.py | 4 +- api/core/prompt/advanced_prompt_transform.py | 51 ++- .../entities}/__init__.py | 0 .../entities/advanced_prompt_entities.py | 42 ++ api/core/prompt/prompt_transform.py | 19 +- api/core/prompt/simple_prompt_transform.py | 11 + api/core/prompt/utils/prompt_message_util.py | 85 ++++ api/core/workflow/entities/node_entities.py | 2 +- api/core/workflow/nodes/answer/__init__.py | 0 .../answer_node.py} | 8 +- .../{direct_answer => answer}/entities.py | 4 +- api/core/workflow/nodes/llm/entities.py | 45 ++- api/core/workflow/nodes/llm/llm_node.py | 370 +++++++++++++++++- api/core/workflow/workflow_engine_manager.py | 47 +-- .../prompt/test_advanced_prompt_transform.py | 77 ++-- 17 files changed, 697 insertions(+), 182 deletions(-) rename api/core/{workflow/nodes/direct_answer => prompt/entities}/__init__.py (100%) create mode 100644 api/core/prompt/entities/advanced_prompt_entities.py create mode 100644 api/core/prompt/utils/prompt_message_util.py create mode 100644 api/core/workflow/nodes/answer/__init__.py rename api/core/workflow/nodes/{direct_answer/direct_answer_node.py => answer/answer_node.py} (91%) rename api/core/workflow/nodes/{direct_answer => answer}/entities.py (75%) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e7ce7f25ef51a4..868e9e724f4081 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -23,7 +23,8 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.simple_prompt_transform import SimplePromptTransform +from 
core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation @@ -155,13 +156,39 @@ def organize_prompt_messages(self, app_record: App, model_config=model_config ) else: + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + prompt_template = CompletionModelPromptTemplate( + text=advanced_completion_prompt_template.prompt + ) + + memory_config.role_prefix = MemoryConfig.RolePrefix( + user=advanced_completion_prompt_template.role_prefix.user, + assistant=advanced_completion_prompt_template.role_prefix.assistant + ) + else: + prompt_template = [] + for message in prompt_template_entity.advanced_chat_prompt_template.messages: + prompt_template.append(ChatModelMessage( + text=message.text, + role=message.role + )) + prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query if query else '', files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 856bfb623d0e22..412029b02491f8 100644 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -30,17 +30,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created @@ -438,7 +433,10 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message = db.session.query(Message).filter(Message.id == self._message.id).first() self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() - self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) + self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + self._model_config.mode, + self._task_state.llm_result.prompt_messages + ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -582,77 +580,6 @@ def _yield_response(self, response: dict) -> str: """ return 
"data: " + json.dumps(response) + "\n\n" - def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: - """ - Prompt messages to prompt for saving. - :param prompt_messages: prompt messages - :return: - """ - prompts = [] - if self._model_config.mode == ModelMode.CHAT.value: - for prompt_message in prompt_messages: - if prompt_message.role == PromptMessageRole.USER: - role = 'user' - elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' - elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' - else: - continue - - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - prompts.append({ - "role": role, - "text": text, - "files": files - }) - else: - prompt_message = prompt_messages[0] - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - params = { - "role": 'user', - "text": text, - } - - if files: - params['files'] = files - - prompts.append(params) - - return prompts - def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. 
diff --git a/api/core/model_manager.py b/api/core/model_manager.py index aa16cf866f9327..8c0633992767dc 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -24,11 +24,11 @@ class ModelInstance: """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: - self._provider_model_bundle = provider_model_bundle + self.provider_model_bundle = provider_model_bundle self.model = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) - self.model_type_instance = self._provider_model_bundle.model_type_instance + self.model_type_instance = self.provider_model_bundle.model_type_instance def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 48b0d8ba021e03..60c77e943b3322 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,5 @@ -from typing import Optional +from typing import Optional, Union -from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -12,6 +11,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. 
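    Editor's sketch (not part of the original patch): the call shape after
    patch 133, using the entities that patch adds in
    core.prompt.entities.advanced_prompt_entities. `model_config` is assumed
    to be a ModelConfigWithCredentialsEntity for a chat-mode model.

        transform = AdvancedPromptTransform()
        prompt_messages = transform.get_prompt(
            prompt_template=[ChatModelMessage(text='Hello, {{name}}.',
                                              role=PromptMessageRole.USER)],
            inputs={'name': 'Dify'},
            query='',
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config,
        )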
""" - def get_prompt(self, prompt_template_entity: PromptTemplateEntity, + def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], inputs: dict, query: str, files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: prompt_messages = [] @@ -34,21 +35,23 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: prompt_messages = self._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) @@ -56,17 +59,18 @@ def get_prompt(self, prompt_template_entity: PromptTemplateEntity, return prompt_messages def _get_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get completion model prompt messages. """ - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt + raw_prompt = prompt_template.text prompt_messages = [] @@ -75,15 +79,17 @@ def _get_completion_model_prompt_messages(self, prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - prompt_inputs = self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config - ) + if memory and memory_config: + role_prefix = memory_config.role_prefix + prompt_inputs = self._set_histories_variable( + memory=memory, + memory_config=memory_config, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) @@ -104,17 +110,18 @@ def _get_completion_model_prompt_messages(self, return prompt_messages def _get_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: list[ChatModelMessage], inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get chat model prompt messages. 
""" - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages + raw_prompt_list = prompt_template prompt_messages = [] @@ -137,8 +144,8 @@ def _get_chat_model_prompt_messages(self, elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if memory: - prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config) + if memory and memory_config: + prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) if files: prompt_message_contents = [TextPromptMessageContent(data=query)] @@ -195,8 +202,9 @@ def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, return prompt_inputs def _set_histories_variable(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, + role_prefix: MemoryConfig.RolePrefix, prompt_template: PromptTemplateParser, prompt_inputs: dict, model_config: ModelConfigWithCredentialsEntity) -> dict: @@ -213,6 +221,7 @@ def _set_histories_variable(self, memory: TokenBufferMemory, histories = self._get_history_messages_from_memory( memory=memory, + memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, ai_prefix=role_prefix.assistant diff --git a/api/core/workflow/nodes/direct_answer/__init__.py b/api/core/prompt/entities/__init__.py similarity index 100% rename from api/core/workflow/nodes/direct_answer/__init__.py rename to api/core/prompt/entities/__init__.py diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py new file mode 100644 index 00000000000000..97ac2e3e2a8651 --- /dev/null +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -0,0 +1,42 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """ + Chat Message. + """ + text: str + role: PromptMessageRole + + +class CompletionModelPromptTemplate(BaseModel): + """ + Completion Model Prompt Template. + """ + text: str + + +class MemoryConfig(BaseModel): + """ + Memory Config. + """ + class RolePrefix(BaseModel): + """ + Role Prefix. + """ + user: str + assistant: str + + class WindowConfig(BaseModel): + """ + Window Config. 
+ """ + enabled: bool + size: Optional[int] = None + + role_prefix: Optional[RolePrefix] = None + window: WindowConfig diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 02e91d91128629..9bf2ae090f7686 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -5,19 +5,22 @@ from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -44,6 +47,7 @@ def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_conf return rest_tokens def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int, human_prefix: Optional[str] = None, ai_prefix: Optional[str] = None) -> str: @@ -58,13 +62,22 @@ def _get_history_messages_from_memory(self, memory: TokenBufferMemory, if ai_prefix: kwargs['ai_prefix'] = ai_prefix + if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: + kwargs['message_limit'] = memory_config.window.size + return memory.get_history_prompt_text( **kwargs ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( - max_token_limit=max_token_limit + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if (memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0) + else 10 ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index ca0efb200c15b1..613716c2cf3b6c 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -13,6 +13,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode @@ -182,6 +183,11 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, if memory: prompt_messages = self._append_chat_histories( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), prompt_messages=prompt_messages, model_config=model_config ) @@ 
-220,6 +226,11 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), max_token_limit=rest_tokens, ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py new file mode 100644 index 00000000000000..5fceeb3595c9d7 --- /dev/null +++ b/api/core/prompt/utils/prompt_message_util.py @@ -0,0 +1,85 @@ +from typing import cast + +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) +from core.prompt.simple_prompt_transform import ModelMode + + +class PromptMessageUtil: + @staticmethod + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + """ + Prompt messages to prompt for saving. + :param model_mode: model mode + :param prompt_messages: prompt messages + :return: + """ + prompts = [] + if model_mode == ModelMode.CHAT.value: + for prompt_message in prompt_messages: + if prompt_message.role == PromptMessageRole.USER: + role = 'user' + elif prompt_message.role == PromptMessageRole.ASSISTANT: + role = 'assistant' + elif prompt_message.role == PromptMessageRole.SYSTEM: + role = 'system' + else: + continue + + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + prompts.append({ + "role": role, + "text": text, + "files": files + }) + else: + prompt_message = prompt_messages[0] + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' 
+ content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + params = { + "role": 'user', + "text": text, + } + + if files: + params['files'] = files + + prompts.append(params) + + return prompts diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 263172da31b88e..befabfb3b4e333 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,7 +12,7 @@ class NodeType(Enum): """ START = 'start' END = 'end' - DIRECT_ANSWER = 'direct-answer' + ANSWER = 'answer' LLM = 'llm' KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' IF_ELSE = 'if-else' diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/answer/answer_node.py similarity index 91% rename from api/core/workflow/nodes/direct_answer/direct_answer_node.py rename to api/core/workflow/nodes/answer/answer_node.py index 22ef2ed53b7e61..381ada1a1e52d0 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -5,14 +5,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.nodes.answer.entities import AnswerNodeData from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData from models.workflow import WorkflowNodeExecutionStatus -class DirectAnswerNode(BaseNode): - _node_data_cls = DirectAnswerNodeData - node_type = NodeType.DIRECT_ANSWER +class AnswerNode(BaseNode): + _node_data_cls = AnswerNodeData + node_type = NodeType.ANSWER def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ diff --git a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/answer/entities.py similarity index 75% rename from api/core/workflow/nodes/direct_answer/entities.py rename to api/core/workflow/nodes/answer/entities.py index e7c11e3c4d1d2e..7c6fed3e4ea6f4 100644 --- a/api/core/workflow/nodes/direct_answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,9 +2,9 @@ from core.workflow.entities.variable_entities import VariableSelector -class DirectAnswerNodeData(BaseNodeData): +class AnswerNodeData(BaseNodeData): """ - DirectAnswer Node Data. + Answer Node Data. """ variables: list[VariableSelector] = [] answer: str diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bd499543d903aa..67163c93cd2b19 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,8 +1,51 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class ContextConfig(BaseModel): + """ + Context Config. 
+ """ + enabled: bool + variable_selector: Optional[list[str]] = None + + +class VisionConfig(BaseModel): + """ + Vision Config. + """ + class Configs(BaseModel): + """ + Configs. + """ + detail: Literal['low', 'high'] + + enabled: bool + configs: Optional[Configs] = None class LLMNodeData(BaseNodeData): """ LLM Node Data. """ - pass + model: ModelConfig + variables: list[VariableSelector] = [] + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + memory: Optional[MemoryConfig] = None + context: ContextConfig + vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 41e28937ac7d2a..d1050a5f5b366d 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,10 +1,27 @@ +from collections.abc import Generator from typing import Optional, cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData +from extensions.ext_database import db +from models.model import Conversation +from models.workflow import WorkflowNodeExecutionStatus class LLMNode(BaseNode): @@ -20,7 +37,341 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - pass + node_inputs = None + process_data = None + + try: + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data, variable_pool) + + node_inputs = { + **inputs + } + + # fetch files + files: list[FileObj] = self._fetch_files(node_data, variable_pool) + + if files: + node_inputs['#files#'] = [{ + 'type': file.type.value, + 'transfer_method': file.transfer_method.value, + 'url': file.url, + 'upload_file_id': file.upload_file_id, + } for file in files] + + # fetch context value + context = self._fetch_context(node_data, variable_pool) + + if context: + node_inputs['#context#'] = context + + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + + # fetch memory + memory = self._fetch_memory(node_data, variable_pool, model_instance) + + # fetch prompt messages + prompt_messages, stop = self._fetch_prompt_messages( + node_data=node_data, + inputs=inputs, + files=files, + context=context, + memory=memory, + 
model_config=model_config + ) + + process_data = { + 'model_mode': model_config.mode, + 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, + prompt_messages=prompt_messages + ) + } + + # handle invoke result + result_text, usage = self._invoke_llm( + node_data=node_data, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop + ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) + + outputs = { + 'text': result_text, + 'usage': jsonable_encoder(usage) + } + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } + ) + + def _invoke_llm(self, node_data: LLMNodeData, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str]) -> tuple[str, LLMUsage]: + """ + Invoke large language model + :param node_data: node data + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data.model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + return self._handle_invoke_result( + invoke_result=invoke_result + ) + + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + """ + Handle invoke result + :param invoke_result: invoke result + :return: + """ + model = None + prompt_messages = [] + full_text = '' + usage = None + for result in invoke_result: + text = result.delta.message.content + full_text += text + + self.publish_text_chunk(text=text) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not usage: + usage = LLMUsage.empty_usage() + + return full_text, usage + + def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + inputs = {} + for variable_selector in node_data.variables: + variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f'Variable {variable_selector.value_selector} not found') + + inputs[variable_selector.variable] = variable_value + + return inputs + + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]: + """ + Fetch files + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.vision.enabled: + return [] + + files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + if not files: + return [] + + return files + + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + """ + Fetch context + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.context.enabled: + return None + + context_value = variable_pool.get_variable_value(node_data.context.variable_selector) + if context_value: + if 
isinstance(context_value, str): + return context_value + elif isinstance(context_value, list): + context_str = '' + for item in context_value: + if 'content' not in item: + raise ValueError(f'Invalid context structure: {item}') + + context_str += item['content'] + '\n' + + return context_str.strip() + + return None + + def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.model.name + provider_name = node_data.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema( + model_name, + model_credentials + ) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory(self, node_data: LLMNodeData, + variable_pool: VariablePool, + model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + """ + Fetch memory + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.memory: + return None + + # get conversation id + conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION]) + if conversation_id is None: + return None + + # get conversation + conversation = db.session.query(Conversation).filter( + Conversation.tenant_id == self.tenant_id, + Conversation.app_id == self.app_id, + Conversation.id == conversation_id + ).first() + + if not conversation: + return None + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + return memory + + def _fetch_prompt_messages(self, node_data: LLMNodeData, + inputs: dict[str, str], + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) \ + -> tuple[list[PromptMessage], 
Optional[list[str]]]: + """ + Fetch prompt messages + :param node_data: node data + :param inputs: inputs + :param files: files + :param context: context + :param memory: memory + :param model_config: model config + :return: + """ + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template=node_data.prompt_template, + inputs=inputs, + query='', + files=files, + context=context, + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + stop = model_config.stop + + return prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: @@ -29,9 +380,20 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) :param node_data: node data :return: """ - # TODO extract variable selector to variable mapping for single step debugging - return {} + node_data = node_data + node_data = cast(cls._node_data_cls, node_data) + + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + if node_data.context.enabled: + variable_mapping['#context#'] = node_data.context.variable_selector + + if node_data.vision.enabled: + variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + return variable_mapping @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 17225c19ea0bb9..49b9d4ac4d7b4c 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -7,9 +7,9 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode @@ -24,13 +24,12 @@ from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, - WorkflowType, ) node_classes = { NodeType.START: StartNode, NodeType.END: EndNode, - NodeType.DIRECT_ANSWER: DirectAnswerNode, + NodeType.ANSWER: AnswerNode, NodeType.LLM: LLMNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.IF_ELSE: IfElseNode, @@ -156,7 +155,7 @@ def run_workflow(self, workflow: Workflow, callbacks=callbacks ) - if next_node.node_type == NodeType.END: + if next_node.node_type in [NodeType.END, NodeType.ANSWER]: break predecessor_node = next_node @@ -402,10 +401,16 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, # add to workflow_nodes_and_results workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool + ) + 
except Exception as e: + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed @@ -420,9 +425,6 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - # set end node output if in chat - self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) - workflow_nodes_and_result.result = node_run_result # node run success @@ -453,29 +455,6 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, db.session.close() - def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, - node: BaseNode, - node_run_result: NodeRunResult) -> None: - """ - Set end node output if in chat - :param workflow_run_state: workflow run state - :param node: current node - :param node_run_result: node run result - :return: - """ - if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: - workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] - if workflow_nodes_and_result_before_end: - if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') - elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer') def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 4357c6405c8a3d..5c08b9f168ad20 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,12 +2,12 @@ import pytest -from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ - ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity +from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages(): model_config_mock.model = 'gpt-3.5-turbo-instruct' prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." 
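+    # The template text and the memory behaviour are configured separately:
+    # CompletionModelPromptTemplate holds the raw template, while MemoryConfig
+    # carries the Human/Assistant role prefixes and the history window.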
- prompt_template_entity = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt=prompt_template, - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) + prompt_template_config = CompletionModelPromptTemplate( + text=prompt_template + ) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix( + user="Human", + assistant="Assistant" + ), + window=MemoryConfig.WindowConfig( + enabled=False ) ) + inputs = { "name": "John" } @@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages(): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template_config, inputs=inputs, query=None, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages(): def test__get_chat_model_prompt_messages(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, memory_config, messages, inputs, context = get_chat_model_args files = [] query = "Hi2." @@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) assert prompt_messages[5].content == query def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, messages, inputs, context = get_chat_model_args files = [] prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=None, files=files, context=context, + memory_config=None, memory=None, model_config=model_config_mock ) @@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, 
messages, inputs, context = get_chat_model_args files = [ FileObj( @@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=None, files=files, context=context, + memory_config=None, memory=None, model_config=model_config_mock ) @@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 @@ -173,22 +181,31 @@ def get_chat_model_args(): model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' - prompt_template_entity = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( - messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ] + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False ) ) + prompt_messages = [ + ChatModelMessage( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM + ), + ChatModelMessage( + text="Hi.", + role=PromptMessageRole.USER + ), + ChatModelMessage( + text="Hello!", + role=PromptMessageRole.ASSISTANT + ) + ] + inputs = { "name": "John" } context = "I am superman." 
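+    # Consumers unpack this fixture as
+    # (model_config_mock, memory_config, messages, inputs, context);
+    # the no-memory tests discard the MemoryConfig with an underscore.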
-    return model_config_mock, prompt_template_entity, inputs, context
+    return model_config_mock, memory_config, prompt_messages, inputs, context


From 2182533af830181f6b88b6c2fa89fa6ed44a91e4 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Tue, 12 Mar 2024 22:41:59 +0800
Subject: [PATCH 134/160] feat: javascript code

---
 api/.env.example                                   |  2 +-
 .../helper/code_executor/code_executor.py          |  8 ++-
 .../code_executor/javascript_transformer.py        | 54 ++++++++++++++++++-
 api/core/workflow/nodes/code/code_node.py          | 17 ++++--
 api/core/workflow/nodes/code/entities.py           |  2 +-
 5 files changed, 73 insertions(+), 10 deletions(-)

diff --git a/api/.env.example b/api/.env.example
index 4a3b1d65afdfc0..c0942412ab948f 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -135,4 +135,4 @@ BATCH_UPLOAD_LIMIT=10
 
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTINO_API_KEY=
+CODE_EXECUTION_API_KEY=
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index 21a8ca5f9f1ba7..adfdf6cc69f57f 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,6 +4,7 @@
 from httpx import post
 from pydantic import BaseModel
 from yarl import URL
+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer
@@ -39,17 +40,20 @@ def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code
         template_transformer = PythonTemplateTransformer
     elif language == 'jinja2':
         template_transformer = Jinja2TemplateTransformer
+    elif language == 'javascript':
+        template_transformer = NodeJsTemplateTransformer
     else:
         raise CodeExecutionException('Unsupported language')
 
     runner = template_transformer.transform_caller(code, inputs)
-
     url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
     headers = {
         'X-Api-Key': CODE_EXECUTION_API_KEY
     }
     data = {
-        'language': language if language != 'jinja2' else 'python3',
+        'language': 'python3' if language == 'jinja2' else
+                    'nodejs' if language == 'javascript' else
+                    'python3' if language == 'python3' else None,
         'code': runner,
     }
diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py
index f87f5c14cbbd7d..cc6ad16c66d833 100644
--- a/api/core/helper/code_executor/javascript_transformer.py
+++ b/api/core/helper/code_executor/javascript_transformer.py
@@ -1 +1,53 @@
-# TODO
\ No newline at end of file
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+NODEJS_RUNNER = """// declare main function here
+{{code}}
+
+// execute main function, and return the result
+// inputs is a dict, unstructured inputs
+output = main({{inputs}})
+
+// convert output to json and print
+output = JSON.stringify(output)
+
+result = `<<RESULT>>${output}<<RESULT>>`
+
+console.log(result)
+"""
+
+
+class NodeJsTemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to nodejs runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = NODEJS_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+        return json.loads(result)
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 2c11e5ba00b9ac..5dfe398711528f 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -15,6 +15,16 @@
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000
 
+JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) {
+    return {
+        result: args1 + args2
+    }
+}"""
+
+PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict:
+    return {
+        "result": args1 + args2,
+    }"""
 
 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
@@ -42,9 +52,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
                 }
             ],
             "code_language": "javascript",
-            "code": "async function main(arg1, arg2) {\n  return new Promise((resolve, reject) => {"
-                    "\n    if (true) {\n      resolve({\n        \"result\": arg1 + arg2"
-                    "\n      });\n    } else {\n      reject(\"e\");\n    }\n  });\n}",
+            "code": JAVASCRIPT_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
@@ -68,8 +76,7 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
                 }
             ],
             "code_language": "python3",
-            "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
-                    "+ arg2\n    }",
+            "code": PYTHON_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index d4d76c45f9f879..97e178f5df9112 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -17,4 +17,4 @@ class Output(BaseModel):
     variables: list[VariableSelector]
     code_language: Literal['python3', 'javascript']
     code: str
-    outputs: dict[str, Output]
+    outputs: dict[str, Output]
\ No newline at end of file

From e6572ef2d76b3cce21ca6a41be1c4c824a63a1d9 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Tue, 12 Mar 2024 22:42:28 +0800
Subject: [PATCH 135/160] fix: linter

---
 api/core/helper/code_executor/code_executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index adfdf6cc69f57f..9d74edee0e5248 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,8 +4,8 @@
 from httpx import post
 from pydantic import BaseModel
 from yarl import URL
-from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 
+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer

From e4794e309a94b992e4504ea93f76196cd04127ad Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 12 Mar 2024 23:08:14 +0800
Subject: [PATCH 136/160] add llm node test

---
 .../workflow/nodes/__init__.py                     |   0
 .../workflow/nodes/test_llm.py                     | 132 ++++++++++++++++++
 .../workflow/nodes/test_template_transform.py      |   4 +-
 .../core/workflow/nodes/__init__.py                |   0
 4 files changed, 134 insertions(+), 2 deletions(-)
 create mode 100644 api/tests/integration_tests/workflow/nodes/__init__.py
 create mode 100644 
api/tests/integration_tests/workflow/nodes/test_llm.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/__init__.py diff --git a/api/tests/integration_tests/workflow/nodes/__init__.py b/api/tests/integration_tests/workflow/nodes/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py new file mode 100644 index 00000000000000..18fba566bf7f71 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -0,0 +1,132 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration +from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.llm.llm_node import LLMNode +from extensions.ext_database import db +from models.provider import ProviderType +from models.workflow import WorkflowNodeExecutionStatus + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_execute_llm(setup_openai_mock): + node = LLMNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'llm', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5.turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'variables': [ + { + 'variable': 'weather', + 'value_selector': ['abc', 'output'], + }, + { + 'variable': 'query', + 'value_selector': ['sys', 'query'] + } + ], + 'prompt_template': [ + { + 'role': 'system', + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.' 
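+                        # {{weather}} is resolved from the 'variables' entry above
+                        # (value_selector ['abc', 'output']); the user message below
+                        # pulls {{query}} from ['sys', 'query'] in the same way.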
+ }, + { + 'role': 'user', + 'text': '{{query}}' + } + ], + 'memory': { + 'window': { + 'enabled': True, + 'size': 2 + } + }, + 'context': { + 'enabled': False + }, + 'vision': { + 'enabled': False + } + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather today?', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION: 'abababa' + }, user_inputs={}) + pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + + credentials = { + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + + provider_instance = ModelProviderFactory().get_provider_instance('openai') + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ) + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_config = ModelConfigWithCredentialsEntity( + model='gpt-3.5-turbo', + provider='openai', + mode='chat', + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), + provider_model_bundle=provider_model_bundle + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['text'] is not None + assert result.outputs['usage']['total_tokens'] > 0 diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 4348995a055026..36cf0a070aa855 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,7 +1,7 @@ import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -14,7 +14,7 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.END_USER, config={ 'id': '1', 'data': { diff --git a/api/tests/unit_tests/core/workflow/nodes/__init__.py b/api/tests/unit_tests/core/workflow/nodes/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 From da3e1e9d14a2b6aa102709898d0469a5962bdb9d Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 00:08:13 +0800 Subject: [PATCH 137/160] add deduct quota for llm node --- api/core/workflow/nodes/llm/llm_node.py | 56 ++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/llm/llm_node.py 
b/api/core/workflow/nodes/llm/llm_node.py index d1050a5f5b366d..9285bbe74e87a1 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -3,6 +3,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -21,6 +22,7 @@ from core.workflow.nodes.llm.entities import LLMNodeData from extensions.ext_database import db from models.model import Conversation +from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus @@ -144,10 +146,15 @@ def _invoke_llm(self, node_data: LLMNodeData, ) # handle invoke result - return self._handle_invoke_result( + text, usage = self._handle_invoke_result( invoke_result=invoke_result ) + # deduct quota + self._deduct_llm_quota(model_instance=model_instance, usage=usage) + + return text, usage + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: """ Handle invoke result @@ -373,6 +380,53 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, return prompt_messages, stop + def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None: + """ + Deduct LLM quota + :param model_instance: model instance + :param usage: usage + :return: + """ + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = 1 + + if 'gpt-4' in model_instance.model: + used_quota = 20 + else: + used_quota = 1 + + if used_quota is not None: + db.session.query(Provider).filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_name == model_instance.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used + ).update({'quota_used': Provider.quota_used + used_quota}) + db.session.commit() + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ From 2b4b6817a3d082c8a4421918b2aef672771bad2f Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 14:55:56 +0800 Subject: [PATCH 138/160] record inputs and process data when node failed --- .../workflow_event_trigger_callback.py | 6 +++++- .../workflow_event_trigger_callback.py | 6 +++++- api/core/app/entities/queue_entities.py | 3 +++ .../callbacks/base_workflow_callback.py | 4 +++- api/core/workflow/workflow_engine_manager.py | 4 +++- api/models/workflow.py | 18 +++++++++--------- .../workflow/nodes/test_llm.py | 2 +- 7 files changed, 29 insertions(+), 14 deletions(-) diff --git 
a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index d9c8a2c96ddb6c..b4a6a9602f6c51 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -96,7 +96,9 @@ def on_workflow_node_execute_succeeded(self, node_id: str, def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ @@ -105,6 +107,8 @@ def on_workflow_node_execute_failed(self, node_id: str, node_id=node_id, node_type=node_type, node_data=node_data, + inputs=inputs, + process_data=process_data, error=error ), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 318466711a3252..ea7eb5688cd754 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -96,7 +96,9 @@ def on_workflow_node_execute_succeeded(self, node_id: str, def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ @@ -105,6 +107,8 @@ def on_workflow_node_execute_failed(self, node_id: str, node_id=node_id, node_type=node_type, node_data=node_data, + inputs=inputs, + process_data=process_data, error=error ), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 0ea7744b582943..153607e1b4473e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -158,6 +158,9 @@ class QueueNodeFailedEvent(AppQueueEvent): node_type: NodeType node_data: BaseNodeData + inputs: Optional[dict] = None + process_data: Optional[dict] = None + error: str diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index cf2915ed8646d4..9594fa20372064 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -55,7 +55,9 @@ def on_workflow_node_execute_succeeded(self, node_id: str, def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 49b9d4ac4d7b4c..ebc753537e172a 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -420,7 +420,9 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node_id=node.node_id, node_type=node.node_type, node_data=node.node_data, - error=node_run_result.error + error=node_run_result.error, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, ) raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") diff --git a/api/models/workflow.py b/api/models/workflow.py index 5a3cdcf83c5570..9c5b2a0b8f3135 
100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -123,11 +123,11 @@ def updated_by_account(self): @property def graph_dict(self): - return self.graph if not self.graph else json.loads(self.graph) + return json.loads(self.graph) if self.graph else None @property def features_dict(self): - return self.features if not self.features else json.loads(self.features) + return json.loads(self.features) if self.features else None def user_input_form(self) -> list: # get start node from graph @@ -270,15 +270,15 @@ def created_by_end_user(self): @property def graph_dict(self): - return self.graph if not self.graph else json.loads(self.graph) + return json.loads(self.graph) if self.graph else None @property def inputs_dict(self): - return self.inputs if not self.inputs else json.loads(self.inputs) + return json.loads(self.inputs) if self.inputs else None @property def outputs_dict(self): - return self.outputs if not self.outputs else json.loads(self.outputs) + return json.loads(self.outputs) if self.outputs else None class WorkflowNodeExecutionTriggeredFrom(Enum): @@ -419,19 +419,19 @@ def created_by_end_user(self): @property def inputs_dict(self): - return self.inputs if not self.inputs else json.loads(self.inputs) + return json.loads(self.inputs) if self.inputs else None @property def outputs_dict(self): - return self.outputs if not self.outputs else json.loads(self.outputs) + return json.loads(self.outputs) if self.outputs else None @property def process_data_dict(self): - return self.process_data if not self.process_data else json.loads(self.process_data) + return json.loads(self.process_data) if self.process_data else None @property def execution_metadata_dict(self): - return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) + return json.loads(self.execution_metadata) if self.execution_metadata else None class WorkflowAppLogCreatedFrom(Enum): diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 18fba566bf7f71..999ebf77342601 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -36,7 +36,7 @@ def test_execute_llm(setup_openai_mock): 'type': 'llm', 'model': { 'provider': 'openai', - 'name': 'gpt-3.5.turbo', + 'name': 'gpt-3.5-turbo', 'mode': 'chat', 'completion_params': {} }, From 5213b0aade7efd50e1df43b055822db49bbbc71c Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 15:01:02 +0800 Subject: [PATCH 139/160] add sequence_number for workflow_started event --- api/core/app/apps/advanced_chat/generate_task_pipeline.py | 1 + api/core/app/apps/workflow/generate_task_pipeline.py | 1 + 2 files changed, 2 insertions(+) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index d5d3feded0328d..e8463e59d3b5d5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -226,6 +226,7 @@ def _process_stream_response(self) -> Generator: 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, + 'sequence_number': workflow_run.sequence_number, 'created_at': int(workflow_run.created_at.timestamp()) } } diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 7a244151f2ddb1..cd1ea4c81eaf79 100644 --- 
a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -195,6 +195,7 @@ def _process_stream_response(self) -> Generator:
                         'data': {
                             'id': workflow_run.id,
                             'workflow_id': workflow_run.workflow_id,
+                            'sequence_number': workflow_run.sequence_number,
                             'created_at': int(workflow_run.created_at.timestamp())
                         }
                     }

From 7e53625eae2fd41ae739e3c7e121555f7a846526 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 13 Mar 2024 15:08:15 +0800
Subject: [PATCH 140/160] fix value type

---
 api/core/workflow/entities/variable_pool.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 3868041a8f28ea..7a5f58d808d1b6 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -13,7 +13,10 @@ class ValueType(Enum):
     STRING = "string"
     NUMBER = "number"
     OBJECT = "object"
-    ARRAY = "array"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    ARRAY_FILE = "array[file]"
     FILE = "file"
@@ -78,7 +81,10 @@ def get_variable_value(self, variable_selector: list[str],
         elif target_value_type == ValueType.OBJECT:
             if not isinstance(value, dict):
                 raise ValueError('Invalid value type: object')
-        elif target_value_type == ValueType.ARRAY:
+        elif target_value_type in [ValueType.ARRAY_STRING,
+                                   ValueType.ARRAY_NUMBER,
+                                   ValueType.ARRAY_OBJECT,
+                                   ValueType.ARRAY_FILE]:
             if not isinstance(value, list):
                 raise ValueError('Invalid value type: array')

From 735b55e61b0751cf5ab75974b0f146474c9c575a Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 13 Mar 2024 17:10:51 +0800
Subject: [PATCH 141/160] add if-else node

---
 api/core/workflow/entities/variable_pool.py |   2 +-
 api/core/workflow/nodes/if_else/entities.py |  26 ++
 .../workflow/nodes/if_else/if_else_node.py  | 395 +++++++++++++++++-
 .../core/workflow/nodes/if_else_node.py     | 193 +++++++++
 4 files changed, 614 insertions(+), 2 deletions(-)
 create mode 100644 api/core/workflow/nodes/if_else/entities.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/if_else_node.py

diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 7a5f58d808d1b6..ff96bc3bac0276 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -86,6 +86,6 @@ def get_variable_value(self, variable_selector: list[str],
                                    ValueType.ARRAY_OBJECT,
                                    ValueType.ARRAY_FILE]:
             if not isinstance(value, list):
-                raise ValueError('Invalid value type: array')
+                raise ValueError(f'Invalid value type: {target_value_type.value}')

         return value
diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py
new file mode 100644
index 00000000000000..68d51c93bee1ac
--- /dev/null
+++ b/api/core/workflow/nodes/if_else/entities.py
@@ -0,0 +1,26 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class IfElseNodeData(BaseNodeData):
+    """
+    If-Else Node Data.
+ """ + class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None + + logical_operator: Literal["and", "or"] = "and" + conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 98a5c85db2df32..9cb084b116dc52 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,398 @@ +from typing import Optional, cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.if_else.entities import IfElseNodeData +from models.workflow import WorkflowNodeExecutionStatus class IfElseNode(BaseNode): - pass + _node_data_cls = IfElseNodeData + node_type = NodeType.IF_ELSE + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + node_inputs = { + "conditions": [] + } + + process_datas = { + "condition_results": [] + } + + try: + logical_operator = node_data.logical_operator + input_conditions = [] + for condition in node_data.conditions: + actual_value = variable_pool.get_variable_value( + variable_selector=condition.variable_selector + ) + + expected_value = condition.value + + input_conditions.append({ + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator + }) + + node_inputs["conditions"] = input_conditions + + for input_condition in input_conditions: + actual_value = input_condition["actual_value"] + expected_value = input_condition["expected_value"] + comparison_operator = input_condition["comparison_operator"] + + if comparison_operator == "contains": + compare_result = self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + compare_result = self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + compare_result = self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + compare_result = self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + compare_result = self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + compare_result = self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + compare_result = self._assert_empty(actual_value) + elif comparison_operator == "not empty": + compare_result = self._assert_not_empty(actual_value) + elif comparison_operator == "=": + compare_result = self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + compare_result = self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + compare_result = self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + compare_result = self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + 
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + compare_result = self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + compare_result = self._assert_null(actual_value) + elif comparison_operator == "not null": + compare_result = self._assert_not_null(actual_value) + else: + continue + + process_datas["condition_results"].append({ + **input_condition, + "result": compare_result + }) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=node_inputs, + process_datas=process_datas, + error=str(e) + ) + + if logical_operator == "and": + compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]] + else: + compare_result = True in [condition["result"] for condition in process_datas["condition_results"]] + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_datas=process_datas, + edge_source_handle="false" if not compare_result else "true", + outputs={ + "result": compare_result + } + ) + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not 
isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param 
expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/tests/unit_tests/core/workflow/nodes/if_else_node.py b/api/tests/unit_tests/core/workflow/nodes/if_else_node.py new file mode 100644 index 00000000000000..7b402ad0a09193 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/if_else_node.py @@ -0,0 +1,193 @@ +from unittest.mock import MagicMock + +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_if_else_result_true(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'and', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'start with', + 'variable_selector': ['start', 'start_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'end with', + 'variable_selector': ['start', 'end_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is', + 'variable_selector': ['start', 'is'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is not', + 'variable_selector': ['start', 'is_not'], + 'value': 'ab' + }, + { + 'comparison_operator': 'empty', + 'variable_selector': ['start', 'empty'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not empty', + 'variable_selector': ['start', 'not_empty'], + 'value': 'ab' + }, + { + 'comparison_operator': '=', + 'variable_selector': ['start', 'equals'], + 'value': '22' + }, + { + 'comparison_operator': '≠', + 'variable_selector': ['start', 'not_equals'], + 'value': '22' + }, + { + 'comparison_operator': '>', + 'variable_selector': ['start', 'greater_than'], + 'value': '22' + }, + { + 'comparison_operator': '<', + 'variable_selector': ['start', 'less_than'], + 'value': '22' + }, + { + 'comparison_operator': '≥', + 'variable_selector': 
['start', 'greater_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': '≤', + 'variable_selector': ['start', 'less_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': 'null', + 'variable_selector': ['start', 'null'] + }, + { + 'comparison_operator': 'not null', + 'variable_selector': ['start', 'not_null'] + }, + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def']) + pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde') + pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde') + pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc') + pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab') + pool.append_variable(node_id='start', variable_key_list=['is'], value='ab') + pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab') + pool.append_variable(node_id='start', variable_key_list=['empty'], value='') + pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa') + pool.append_variable(node_id='start', variable_key_list=['equals'], value=22) + pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23) + pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23) + pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21) + pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22) + pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21) + pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is True + + +def test_execute_if_else_result_false(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'or', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + } + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def']) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is False From 6b19ba3bb2821733f4ac1be91266bfde7c0d9eeb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 13 Mar 2024 17:46:42 +0800 Subject: [PATCH 142/160] enhance: sandbox-docker-compose --- api/.env.example | 4 ++-- docker/docker-compose.middleware.yaml | 3 +++ docker/docker-compose.yaml | 3 +++ 3 
files changed, 8 insertions(+), 2 deletions(-) diff --git a/api/.env.example b/api/.env.example index c0942412ab948f..832c7e3bab6c58 100644 --- a/api/.env.example +++ b/api/.env.example @@ -134,5 +134,5 @@ SSRF_PROXY_HTTPS_URL= BATCH_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION -CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 8fba59c3154415..4f7965609b7089 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -55,9 +55,12 @@ services: sandbox: image: langgenius/dify-sandbox:latest restart: always + cap_add: + - SYS_ADMIN environment: # The DifySandbox configurations API_KEY: dify-sandbox + GIN_MODE: 'release' ports: - "8194:8194" diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ca6b6cbf1a17d2..f066582ac8e66e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -293,9 +293,12 @@ services: sandbox: image: langgenius/dify-sandbox:latest restart: always + cap_add: + - SYS_ADMIN environment: # The DifySandbox configurations API_KEY: dify-sandbox + GIN_MODE: release ports: - "8194:8194" From e5ff06bcb78a39691410fcff4e34528040c5b1b3 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 18:02:07 +0800 Subject: [PATCH 143/160] fix err typo --- api/core/workflow/nodes/if_else/if_else_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 9cb084b116dc52..44a4091a2efc6e 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -95,7 +95,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, - process_datas=process_datas, + process_data=process_datas, error=str(e) ) @@ -107,7 +107,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, - process_datas=process_datas, + process_data=process_datas, edge_source_handle="false" if not compare_result else "true", outputs={ "result": compare_result From 0614ddde7dedc0465eb827e40dc170d965f6651a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 13 Mar 2024 20:40:37 +0800 Subject: [PATCH 144/160] fix: allow None AuthorizationConfig --- .../workflow/nodes/http_request/entities.py | 17 +++++++++-- .../workflow/nodes/test_http.py | 30 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index ce806b6bdbad85..fbd4da384004a4 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,6 +1,6 @@ from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -17,7 +17,20 @@ class Config(BaseModel): header: Union[None, str] type: Literal['no-auth', 'api-key'] - config: Config + config: Optional[Config] + + @validator('config', always=True, pre=True) + def check_config(cls, v, values): + """ + Check config, if type is no-auth, 
config should be None, otherwise it should be a dict. + """ + if values['type'] == 'no-auth': + return None + else: + if not v or not isinstance(v, dict): + raise ValueError('config should be a dict') + + return v class Body(BaseModel): type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 6df8f6b6733de2..584e1d80a59d38 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -54,6 +54,36 @@ def test_get(setup_http_mock): assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_no_auth(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'no-auth', + 'config': None, + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'X-Header: 123' in data + @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) def test_template(setup_http_mock): node = HttpRequestNode(config={ From 5a67c09b48d18a398a23b86be5f33c39fcabec0a Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 20:54:23 +0800 Subject: [PATCH 145/160] use answer node instead of end in advanced chatbot --- api/services/workflow/workflow_converter.py | 67 ++++++++++++--------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 4c7e4db47af5f5..78f79e02faca5f 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -19,7 +19,6 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.end.entities import EndNodeOutputType from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account @@ -149,10 +148,13 @@ def convert_app_model_config_to_workflow(self, app_model: App, graph = self._append_node(graph, llm_node) - # convert to end node by app mode - end_node = self._convert_to_end_node(app_model=app_model) - - graph = self._append_node(graph, end_node) + if new_app_mode == AppMode.WORKFLOW: + # convert to end node by app mode + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + else: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) app_model_config_dict = app_config.app_model_config_dict @@ -517,35 +519,44 @@ def _convert_to_llm_node(self, new_app_mode: AppMode, } } - def _convert_to_end_node(self, app_model: App) -> dict: + def _convert_to_end_node(self) -> dict: """ Convert to End Node - :param app_model: App instance :return: """ - if app_model.mode == AppMode.CHAT.value: - return { - "id": "end", - "position": None, - "data": { - "title": "END", - "type": NodeType.END.value, + # for original completion app + return { + "id": "end", + "position": None, + "data": { + 
"title": "END", + "type": NodeType.END.value, + "outputs": { + "variable": "result", + "value_selector": ["llm", "text"] } } - elif app_model.mode == AppMode.COMPLETION.value: - # for original completion app - return { - "id": "end", - "position": None, - "data": { - "title": "END", - "type": NodeType.END.value, - "outputs": { - "type": EndNodeOutputType.PLAIN_TEXT.value, - "plain_text_selector": ["llm", "text"] - } - } + } + + def _convert_to_answer_node(self) -> dict: + """ + Convert to Answer Node + :return: + """ + # for original chat app + return { + "id": "answer", + "position": None, + "data": { + "title": "ANSWER", + "type": NodeType.ANSWER.value, + "variables": { + "variable": "text", + "value_selector": ["llm", "text"] + }, + "answer": "{{text}}" } + } def _create_edge(self, source: str, target: str) -> dict: """ @@ -582,7 +593,7 @@ def _get_new_app_mode(self, app_model: App) -> AppMode: if app_model.mode == AppMode.COMPLETION.value: return AppMode.WORKFLOW else: - return AppMode.value_of(app_model.mode) + return AppMode.ADVANCED_CHAT def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: """ From 44c4d5be72d2fcfd2930de377015968f2f75ae22 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 23:00:28 +0800 Subject: [PATCH 146/160] add answer output parse --- .../workflow_event_trigger_callback.py | 31 +--------- api/core/workflow/nodes/answer/answer_node.py | 50 +++++++++++++-- api/core/workflow/nodes/base_node.py | 14 +---- api/core/workflow/nodes/end/end_node.py | 38 +++--------- api/core/workflow/nodes/end/entities.py | 61 +------------------ api/core/workflow/workflow_engine_manager.py | 4 ++ api/services/workflow/workflow_converter.py | 4 +- .../core/workflow/nodes/test_answer.py | 56 +++++++++++++++++ .../{if_else_node.py => test_if_else.py} | 0 9 files changed, 120 insertions(+), 138 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_answer.py rename api/tests/unit_tests/core/workflow/nodes/{if_else_node.py => test_if_else.py} (100%) diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index ea7eb5688cd754..59ef44cd2e4ade 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -5,7 +5,6 @@ QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, @@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self) -> None: """ @@ -118,31 +116,4 @@ def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ - if node_id in self._streamable_node_ids: - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) - - def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param graph: workflow graph - :return: - """ - 
streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('data', {}).get('type') == NodeType.END.value: - if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids + pass diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 381ada1a1e52d0..97ddafad019470 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,4 +1,3 @@ -import time from typing import cast from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -32,14 +31,49 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: variable_values[variable_selector.variable] = value + variable_keys = list(variable_values.keys()) + # format answer template template_parser = PromptTemplateParser(node_data.answer) - answer = template_parser.format(variable_values) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + split_template = [ + { + "type": "var" if self._is_variable(part, variable_keys) else "text", + "value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part + } + for part in template.split('Ω') if part + ] + + answer = [] + for part in split_template: + if part["type"] == "var": + value = variable_values.get(part["value"].replace('{{', '').replace('}}', '')) + answer_part = { + "type": "text", + "text": value + } + # TODO File + else: + answer_part = { + "type": "text", + "text": part["value"] + } - # publish answer as stream - for word in answer: - self.publish_text_chunk(word) - time.sleep(10) # TODO for debug + if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": + answer[-1]["text"] += answer_part["text"] + else: + answer.append(answer_part) + + if len(answer) == 1 and answer[0]["type"] == "text": + answer = answer[0]["text"] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -49,6 +83,10 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: } ) + def _is_variable(self, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index dfba9d0385ee0b..2da19bc409d379 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -6,7 +6,6 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from models.workflow import WorkflowNodeExecutionStatus class UserFrom(Enum): @@ -80,16 +79,9 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult: :param variable_pool: variable pool :return: """ - try: - result = self._run( - 
variable_pool=variable_pool - ) - except Exception as e: - # process unhandled exception - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) + result = self._run( + variable_pool=variable_pool + ) self.node_run_result = result return result diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2666ccc4f97a04..3241860c298a93 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs +from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -20,34 +20,14 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - outputs_config = node_data.outputs + output_variables = node_data.outputs - outputs = None - if outputs_config: - if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: - plain_text_selector = outputs_config.plain_text_selector - if plain_text_selector: - outputs = { - 'text': variable_pool.get_variable_value( - variable_selector=plain_text_selector, - target_value_type=ValueType.STRING - ) - } - else: - outputs = { - 'text': '' - } - elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: - structured_variables = outputs_config.structured_variables - if structured_variables: - outputs = {} - for variable_selector in structured_variables: - variable_value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - outputs[variable_selector.variable] = variable_value - else: - outputs = {} + outputs = {} + for variable_selector in output_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index 32212ae7faccbc..ad4fc8f04fd43c 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,68 +1,9 @@ -from enum import Enum -from typing import Optional - -from pydantic import BaseModel - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector -class EndNodeOutputType(Enum): - """ - END Node Output Types. - - none, plain-text, structured - """ - NONE = 'none' - PLAIN_TEXT = 'plain-text' - STRUCTURED = 'structured' - - @classmethod - def value_of(cls, value: str) -> 'OutputType': - """ - Get value of given output type. - - :param value: output type value - :return: output type - """ - for output_type in cls: - if output_type.value == value: - return output_type - raise ValueError(f'invalid output type value {value}') - - -class EndNodeDataOutputs(BaseModel): - """ - END Node Data Outputs. - """ - class OutputType(Enum): - """ - Output Types. 
- """ - NONE = 'none' - PLAIN_TEXT = 'plain-text' - STRUCTURED = 'structured' - - @classmethod - def value_of(cls, value: str) -> 'OutputType': - """ - Get value of given output type. - - :param value: output type value - :return: output type - """ - for output_type in cls: - if output_type.value == value: - return output_type - raise ValueError(f'invalid output type value {value}') - - type: OutputType = OutputType.NONE - plain_text_selector: Optional[list[str]] = None - structured_variables: Optional[list[VariableSelector]] = None - - class EndNodeData(BaseNodeData): """ END Node Data. """ - outputs: Optional[EndNodeDataOutputs] = None + outputs: list[VariableSelector] diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index ebc753537e172a..3109f9ea330f80 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,3 +1,4 @@ +import logging import time from typing import Optional @@ -41,6 +42,8 @@ NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, } +logger = logging.getLogger(__name__) + class WorkflowEngineManager: def get_default_configs(self) -> list[dict]: @@ -407,6 +410,7 @@ def _run_workflow_node(self, workflow_run_state: WorkflowRunState, variable_pool=workflow_run_state.variable_pool ) except Exception as e: + logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") node_run_result = NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 78f79e02faca5f..953c5c5a3cdb73 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -531,10 +531,10 @@ def _convert_to_end_node(self) -> dict: "data": { "title": "END", "type": NodeType.END.value, - "outputs": { + "outputs": [{ "variable": "result", "value_selector": ["llm", "text"] - } + }] } } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py new file mode 100644 index 00000000000000..bad5d42a43e023 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -0,0 +1,56 @@ +from unittest.mock import MagicMock + +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_answer(): + node = AnswerNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'answer', + 'data': { + 'title': '123', + 'type': 'answer', + 'variables': [ + { + 'value_selector': ['llm', 'text'], + 'variable': 'text' + }, + { + 'value_selector': ['start', 'weather'], + 'variable': 'weather' + }, + ], + 'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.' 
+ } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny') + pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + + +# TODO test files diff --git a/api/tests/unit_tests/core/workflow/nodes/if_else_node.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py similarity index 100% rename from api/tests/unit_tests/core/workflow/nodes/if_else_node.py rename to api/tests/unit_tests/core/workflow/nodes/test_if_else.py From 6633a92e1aef02aae56d6c0a1caa11aa3e7671fa Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:35:51 +0800 Subject: [PATCH 147/160] fix: http --- .../workflow/nodes/http_request/entities.py | 2 +- .../nodes/http_request/http_executor.py | 6 +- .../nodes/http_request/http_request_node.py | 2 +- .../workflow/nodes/test_http.py | 74 +++++++++++++++++++ 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index fbd4da384004a4..0683008954de26 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -33,7 +33,7 @@ def check_config(cls, v, values): return v class Body(BaseModel): - type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] + type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] data: Union[None, str] variables: list[VariableSelector] diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index c96d5f07d19a82..3d307be0d1f58b 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -131,8 +131,6 @@ def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, An self.headers['Content-Type'] = 'application/json' elif node_data.body.type == 'x-www-form-urlencoded': self.headers['Content-Type'] = 'application/x-www-form-urlencoded' - # elif node_data.body.type == 'form-data': - # self.headers['Content-Type'] = 'multipart/form-data' if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: body = {} @@ -152,8 +150,10 @@ def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, An } else: self.body = urlencode(body) - else: + elif node_data.body.type in ['json', 'raw']: self.body = original_body + elif node_data.body.type == 'none': + self.body = '' def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index c83e331fa8f4c3..a914ae13ff1b0b 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -42,7 +42,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: inputs=variables, outputs={ 'status_code': response.status_code, - 'body': response, + 'body': response.body, 'headers': response.headers }, process_data={ diff --git 
a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 584e1d80a59d38..8b94105b44527f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -84,6 +84,41 @@ def test_no_auth(setup_http_mock): assert '?A=b' in data assert 'X-Header: 123' in data +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'custom', + 'api_key': 'Auth', + 'header': 'X-Auth', + }, + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'X-Header: 123' in data + assert 'X-Auth: Auth' in data + @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) def test_template(setup_http_mock): node = HttpRequestNode(config={ @@ -237,3 +272,42 @@ def test_form_data(setup_http_mock): assert '2' in data assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data + +def test_none_data(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'none', + 'data': '123123123' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + assert '123123123' not in data \ No newline at end of file From fb6e5bf4d5f40165ef41c5e850ae50467300aec7 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 11:39:05 +0800 Subject: [PATCH 148/160] fix publish route --- api/controllers/console/app/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6f81da569126f3..d5967dd5ed31ed 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -299,7 +299,7 @@ def generate() -> Generator: api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps//workflows/tasks//stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') -api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') +api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' '/') From c2ded79cb2bd752866b42ca8b0a9640da1be9e66 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:58:56 +0800 Subject: [PATCH 149/160] fix: node type --- api/core/workflow/nodes/tool/tool_node.py | 4 ++-- 1 
file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c62e025e75fed1..89c83890854535 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -136,12 +136,12 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping """ return { k.variable: k.value_selector - for k in cast(ToolNodeData, node_data).tool_parameters + for k in node_data.tool_parameters if k.variable_type == 'selector' } From 87a36a1fc8ba3c88646d884761a5e19b108fcefb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:59:33 +0800 Subject: [PATCH 150/160] fix: linter --- api/core/workflow/nodes/tool/tool_node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 89c83890854535..b03ad45e6ce8cd 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -5,7 +5,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode From 72d2f76d2444a14ec34f6ba1dbf8d098241a96d3 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 12:12:26 +0800 Subject: [PATCH 151/160] fix default configs --- api/core/workflow/workflow_engine_manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3109f9ea330f80..a7379e6e99fceb 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -54,10 +54,7 @@ def get_default_configs(self) -> list[dict]: for node_type, node_class in node_classes.items(): default_config = node_class.get_default_config() if default_config: - default_block_configs.append({ - 'type': node_type.value, - 'config': default_config - }) + default_block_configs.append(default_config) return default_block_configs From 737321da756dc16cd882f01204129bc0febc567b Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 12:17:15 +0800 Subject: [PATCH 152/160] add advanced chat apis support --- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/conversation.py | 8 ++++---- api/controllers/console/app/message.py | 4 ++-- api/controllers/console/app/statistic.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 4de4a6f3fe82e6..29d89ae4603b75 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, 
AppMode.ADVANCED_CHAT]) def post(self, app_model): file = request.files['file'] diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 33711076f8597c..11dece3a9e58e5 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -112,7 +112,7 @@ def get(self, app_model, conversation_id): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -133,7 +133,7 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -227,7 +227,7 @@ def get(self, app_model, conversation_id): @setup_required @login_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 111ec7d787e58a..56d2e718e7d4d1 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -42,7 +42,7 @@ class ChatMessageListApi(Resource): @setup_required @login_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): @@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model, message_id): message_id = str(message_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 51fe53c0ec7050..d687b52dc8e6ec 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model): account = current_user From 6e51ce123c66feb738eebbe8740e2ebb509612a2 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 12:56:25 +0800 Subject: [PATCH 153/160] fix: null conversation id --- ...nable_tool_file_without_conversation_id.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create 
mode 100644 api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py new file mode 100644 index 00000000000000..d91288bcf5d7d9 --- /dev/null +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -0,0 +1,36 @@ +"""enable tool file without conversation id + +Revision ID: 563cf8bf777b +Revises: b5429b71023c +Create Date: 2024-03-14 04:54:56.679506 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '563cf8bf777b' +down_revision = 'b5429b71023c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### From 74e644be1ca1436b1c7ef265158fceab761662f4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 12:56:57 +0800 Subject: [PATCH 154/160] fix: linter --- .../563cf8bf777b_enable_tool_file_without_conversation_id.py | 1 - api/models/tools.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index d91288bcf5d7d9..299f442de989be 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -6,7 +6,6 @@ """ from alembic import op -import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
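
Editor's note, not part of the patch: with conversation_id now nullable, tool files produced by workflow nodes can be persisted without a conversation. A minimal usage sketch, assuming a Flask-SQLAlchemy session is available; columns other than conversation_id and file_key are illustrative:

    from extensions.ext_database import db
    from models.tools import ToolFile

    tool_file = ToolFile(
        tenant_id=tenant_id,           # assumed to be in scope
        conversation_id=None,          # allowed after this migration
        file_key='tools/example.png',  # hypothetical storage key
    )
    db.session.add(tool_file)
    db.session.commit()
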
diff --git a/api/models/tools.py b/api/models/tools.py index bceef7a8290151..4bdf2503ce0619 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -218,7 +218,7 @@ class ToolFile(db.Model): # tenant id tenant_id = db.Column(UUID, nullable=False) # conversation id - conversation_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, nullable=True) # file key file_key = db.Column(db.String(255), nullable=False) # mime type From dc53362506f8453237030b560cf2a8d884f8290b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 13:24:48 +0800 Subject: [PATCH 155/160] fix: conversation_id equals to none --- api/core/workflow/nodes/tool/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b03ad45e6ce8cd..ca217182ccfd5c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -78,7 +78,7 @@ def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str messages=messages, user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id='', + conversation_id=None, ) # extract plain text and files files = self._extract_tool_response_binary(messages) From ede65eca4d9c14484ea1b4674febae2b1ddb20c9 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 16:38:22 +0800 Subject: [PATCH 156/160] fix: tool --- api/core/workflow/nodes/tool/entities.py | 11 +++++++++-- api/core/workflow/nodes/tool/tool_node.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 0b3bf76aacfd2a..7eb3cf655b1bdc 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -3,7 +3,6 @@ from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector ToolParameterValue = Union[str, int, float, bool] @@ -16,8 +15,10 @@ class ToolEntity(BaseModel): tool_configurations: dict[str, ToolParameterValue] class ToolNodeData(BaseNodeData, ToolEntity): - class ToolInput(VariableSelector): + class ToolInput(BaseModel): + variable: str variable_type: Literal['selector', 'static'] + value_selector: Optional[list[str]] value: Optional[str] @validator('value') @@ -25,6 +26,12 @@ def check_value(cls, value, values, **kwargs): if values['variable_type'] == 'static' and value is None: raise ValueError('value is required for static variable') return value + + @validator('value_selector') + def check_value_selector(cls, value_selector, values, **kwargs): + if values['variable_type'] == 'selector' and value_selector is None: + raise ValueError('value_selector is required for selector variable') + return value_selector """ Tool Node Schema diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ca217182ccfd5c..d0bfd9e7973467 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -44,7 +44,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters, - error=f'Failed to invoke tool: {str(e)}' + error=f'Failed to invoke tool: {str(e)}', ) # convert tool messages @@ -56,6 +56,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: 'text': plain_text, 'files': files }, + 
inputs=parameters ) def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: From 1cfeb989f77fde324866ef9897268cabbcb3c747 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 19:17:27 +0800 Subject: [PATCH 157/160] fix: code default output --- api/core/workflow/nodes/code/code_node.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 5dfe398711528f..0b46f86e9d22e7 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -53,12 +53,12 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: ], "code_language": "javascript", "code": JAVASCRIPT_DEFAULT_CODE, - "outputs": [ - { - "variable": "result", - "variable_type": "number" + "outputs": { + "result": { + "type": "number", + "children": None } - ] + } } } @@ -77,12 +77,12 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: ], "code_language": "python3", "code": PYTHON_DEFAULT_CODE, - "outputs": [ - { - "variable": "result", - "variable_type": "number" + "outputs": { + "result": { + "type": "number", + "children": None } - ] + } } } From 12eb2363646b316999a766633a7cee723e623e06 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 20:49:53 +0800 Subject: [PATCH 158/160] answer stream output support --- .../advanced_chat/generate_task_pipeline.py | 277 +++++++++++++++++- .../workflow_event_trigger_callback.py | 39 +-- .../apps/message_based_app_queue_manager.py | 6 +- .../workflow_event_trigger_callback.py | 2 +- api/core/app/entities/queue_entities.py | 11 +- .../callbacks/base_workflow_callback.py | 2 +- api/core/workflow/nodes/answer/answer_node.py | 119 +++++--- api/core/workflow/nodes/answer/entities.py | 26 ++ api/core/workflow/nodes/base_node.py | 9 +- api/core/workflow/nodes/llm/llm_node.py | 2 +- 10 files changed, 408 insertions(+), 85 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e8463e59d3b5d5..ca4b143027c512 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from pydantic import BaseModel, Extra @@ -13,6 +13,7 @@ InvokeFrom, ) from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, QueueAnnotationReplyEvent, QueueErrorEvent, QueueMessageFileEvent, @@ -34,6 +35,8 @@ from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account @@ -51,15 +54,26 @@ logger = logging.getLogger(__name__) +class StreamGenerateRoute(BaseModel): + """ + StreamGenerateRoute entity + """ + answer_node_id: str + generate_route: list[GenerateRouteChunk] + current_route_position: int = 0 + + class TaskState(BaseModel): """ TaskState entity """ + class 
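
Editor's note (appended; not part of the patch series): patches 146 and 158 both depend on splitting an answer template into ordered text and variable chunks. A standalone sketch of the Ω-sentinel technique used in answer_node.py, with a hypothetical template and variable set:

    template = "Today's weather is {{weather}}, enjoy!"
    variable_keys = ['weather']  # intersection of pool variables and template variables

    # Wrap each known variable reference in a sentinel assumed absent from the
    # template, then split on the sentinel to obtain ordered chunks.
    for var in variable_keys:
        template = template.replace('{{' + var + '}}', 'Ω{{' + var + '}}Ω')

    parts = [part for part in template.split('Ω') if part]
    # parts == ["Today's weather is ", '{{weather}}', ', enjoy!']
    # '{{weather}}' is emitted as a "var" chunk; the surrounding parts remain "text" chunks.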
NodeExecutionInfo(BaseModel): """ NodeExecutionInfo entity """ workflow_node_execution_id: str + node_type: NodeType start_at: float class Config: @@ -77,9 +91,11 @@ class Config: total_tokens: int = 0 total_steps: int = 0 - running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} latest_node_execution_info: Optional[NodeExecutionInfo] = None + current_stream_generate_state: Optional[StreamGenerateRoute] = None + class Config: """Configuration for this pydantic object.""" @@ -122,6 +138,11 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream + if stream: + self._stream_generate_routes = self._get_stream_generate_routes() + else: + self._stream_generate_routes = None + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. @@ -290,6 +311,11 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(data) break + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, @@ -309,7 +335,7 @@ def _process_stream_response(self) -> Generator: } yield self._yield_response(workflow_run_response) - + elif isinstance(event, QueueAdvancedChatMessageEndEvent): # response moderation if self._output_moderation_handler: self._output_moderation_handler.stop_thread() @@ -390,6 +416,11 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueTextChunkEvent): + if not self._is_stream_out_support( + event=event + ): + continue + delta_text = event.text if delta_text is None: continue @@ -467,20 +498,28 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: latest_node_execution_info = TaskState.NodeExecutionInfo( workflow_node_execution_id=workflow_node_execution.id, + node_type=event.node_type, start_at=time.perf_counter() ) - self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info self._task_state.latest_node_execution_info = latest_node_execution_info self._task_state.total_steps += 1 db.session.close() + # search stream_generate_routes if node id is answer start at node + if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: + self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] + + # stream outputs from start + self._generate_stream_outputs_when_node_start() + return workflow_node_execution def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): @@ -508,8 +547,8 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven error=event.error ) - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] + # stream outputs when node finished + 
self._generate_stream_outputs_when_node_finished() db.session.close() @@ -517,7 +556,8 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ -> WorkflowRun: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_run = (db.session.query(WorkflowRun) + .filter(WorkflowRun.id == self._task_state.workflow_run_id).first()) if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( workflow_run=workflow_run, @@ -642,7 +682,7 @@ def _error_to_stream_response_data(self, e: Exception) -> dict: QuotaExceededError: { 'code': 'provider_quota_exceeded', 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", + "Please go to Settings -> Model Provider to complete your own provider credentials.", 'status': 400 }, ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, @@ -660,10 +700,10 @@ def _error_to_stream_response_data(self, e: Exception) -> dict: else: logging.error(e) data = { - 'code': 'internal_server_error', + 'code': 'internal_server_error', 'message': 'Internal Server Error, please contact support.', 'status': 500 - } + } return { 'event': 'error', @@ -730,3 +770,218 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: ), queue_manager=self._queue_manager ) + + def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]: + """ + Get stream generate routes. + :return: + """ + # find all answer nodes + graph = self._workflow.graph_dict + answer_node_configs = [ + node for node in graph['nodes'] + if node.get('data', {}).get('type') == NodeType.ANSWER.value + ] + + # parse stream output node value selectors of answer nodes + stream_generate_routes = {} + for node_config in answer_node_configs: + # get generate route for stream output + answer_node_id = node_config['id'] + generate_route = AnswerNode.extract_generate_route_selectors(node_config) + start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id) + if not start_node_id: + continue + + stream_generate_routes[start_node_id] = StreamGenerateRoute( + answer_node_id=answer_node_id, + generate_route=generate_route + ) + + return stream_generate_routes + + def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \ + -> Optional[str]: + """ + Get answer start at node id. 
+ :param graph: graph + :param target_node_id: target node ID + :return: + """ + nodes = graph.get('nodes') + edges = graph.get('edges') + + # fetch all ingoing edges from source node + ingoing_edge = None + for edge in edges: + if edge.get('target') == target_node_id: + ingoing_edge = edge + break + + if not ingoing_edge: + return None + + source_node_id = ingoing_edge.get('source') + source_node = next((node for node in nodes if node.get('id') == source_node_id), None) + if not source_node: + return None + + node_type = source_node.get('data', {}).get('type') + if node_type in [ + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER + ]: + start_node_id = target_node_id + elif node_type == NodeType.START.value: + start_node_id = source_node_id + else: + start_node_id = self._get_answer_start_at_node_id(graph, source_node_id) + + return start_node_id + + def _generate_stream_outputs_when_node_start(self) -> None: + """ + Generate stream outputs. + :return: + """ + if not self._task_state.current_stream_generate_state: + return + + for route_chunk in self._task_state.current_stream_generate_state.generate_route: + if route_chunk.type == 'text': + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + for token in route_chunk.text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + + self._task_state.current_stream_generate_state.current_route_position += 1 + else: + break + + # all route chunks are generated + if self._task_state.current_stream_generate_state.current_route_position == len( + self._task_state.current_stream_generate_state.generate_route): + self._task_state.current_stream_generate_state = None + + def _generate_stream_outputs_when_node_finished(self) -> None: + """ + Generate stream outputs. 
+ :return: + """ + if not self._task_state.current_stream_generate_state: + return + + route_chunks = self._task_state.current_stream_generate_state.generate_route[ + self._task_state.current_stream_generate_state.current_route_position:] + + for route_chunk in route_chunks: + if route_chunk.type == 'text': + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + for token in route_chunk.text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + else: + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + route_chunk_node_id = value_selector[0] + + # check chunk node id is before current node id or equal to current node id + if route_chunk_node_id not in self._task_state.ran_node_execution_infos: + break + + latest_node_execution_info = self._task_state.latest_node_execution_info + + # get route chunk node execution info + route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] + if (route_chunk_node_execution_info.node_type == NodeType.LLM + and latest_node_execution_info.node_type == NodeType.LLM): + # only LLM support chunk stream output + self._task_state.current_stream_generate_state.current_route_position += 1 + continue + + # get route chunk node execution + route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first() + + outputs = route_chunk_node_execution.outputs_dict + + # get value from outputs + value = None + for key in value_selector[1:]: + if not value: + value = outputs.get(key) + else: + value = value.get(key) + + if value: + text = None + if isinstance(value, str | int | float): + text = str(value) + elif isinstance(value, object): # TODO FILE + # convert file to markdown + text = f'![]({value.get("url")})' + pass + + if text: + for token in text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + + self._task_state.current_stream_generate_state.current_route_position += 1 + + # all route chunks are generated + if self._task_state.current_stream_generate_state.current_route_position == len( + self._task_state.current_stream_generate_state.generate_route): + self._task_state.current_stream_generate_state = None + + def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.metadata: + return True + + if 'node_id' not in event.metadata: + return True + + node_type = event.metadata.get('node_type') + stream_output_value_selector = event.metadata.get('value_selector') + if not stream_output_value_selector: + return False + + if not self._task_state.current_stream_generate_state: + return False + + route_chunk = self._task_state.current_stream_generate_state.generate_route[ + self._task_state.current_stream_generate_state.current_route_position] + + if route_chunk.type != 'var': + return False + + if node_type != NodeType.LLM: + # only LLM support chunk stream output + return False + + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + return False + + return True diff --git 
a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index b4a6a9602f6c51..972fda2d49a66c 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self) -> None: """ @@ -114,34 +113,16 @@ def on_workflow_node_execute_failed(self, node_id: str, PublishFrom.APPLICATION_MANAGER ) - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ - if node_id in self._streamable_node_ids: - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) - - def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param graph: workflow graph - :return: - """ - streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('data', {}).get('type') == NodeType.END.value: - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids + self._queue_manager.publish( + QueueTextChunkEvent( + text=text, + metadata={ + "node_id": node_id, + **metadata + } + ), PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 6d0a71f495e328..f4ff44dddac9ef 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -3,12 +3,11 @@ from core.app.entities.queue_entities import ( AppQueueEvent, MessageQueueMessage, + QueueAdvancedChatMessageEndEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent, QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, ) @@ -54,8 +53,7 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + | QueueAdvancedChatMessageEndEvent): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 59ef44cd2e4ade..e5a8e8d3747c42 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -112,7 +112,7 @@ def on_workflow_node_execute_failed(self, node_id: str, PublishFrom.APPLICATION_MANAGER ) - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ diff 
--git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 153607e1b4473e..5c31996fd345a6 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -17,6 +17,7 @@ class QueueEvent(Enum): AGENT_MESSAGE = "agent_message" MESSAGE_REPLACE = "message_replace" MESSAGE_END = "message_end" + ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end" WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" @@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent): """ event = QueueEvent.TEXT_CHUNK text: str + metadata: Optional[dict] = None class QueueAgentMessageEvent(AppQueueEvent): @@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent): QueueMessageEndEvent entity """ event = QueueEvent.MESSAGE_END - llm_result: LLMResult + llm_result: Optional[LLMResult] = None + + +class QueueAdvancedChatMessageEndEvent(AppQueueEvent): + """ + QueueAdvancedChatMessageEndEvent entity + """ + event = QueueEvent.ADVANCED_CHAT_MESSAGE_END class QueueWorkflowStartedEvent(AppQueueEvent): diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 9594fa20372064..1f5472b430c96a 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -64,7 +64,7 @@ def on_workflow_node_execute_failed(self, node_id: str, raise NotImplementedError @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 97ddafad019470..d8ff5cb6f630d1 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -4,7 +4,12 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool -from core.workflow.nodes.answer.entities import AnswerNodeData +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) from core.workflow.nodes.base_node import BaseNode from models.workflow import WorkflowNodeExecutionStatus @@ -22,49 +27,29 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - variable_values = {} - for variable_selector in node_data.variables: - value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector, - target_value_type=ValueType.STRING - ) - - variable_values[variable_selector.variable] = value - - variable_keys = list(variable_values.keys()) - - # format answer template - template_parser = PromptTemplateParser(node_data.answer) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - split_template = [ - { - "type": "var" if self._is_variable(part, variable_keys) else "text", - "value": part.replace('Ω', '') if 
self._is_variable(part, variable_keys) else part - } - for part in template.split('Ω') if part - ] + # generate routes + generate_routes = self.extract_generate_route_from_node_data(node_data) answer = [] - for part in split_template: - if part["type"] == "var": - value = variable_values.get(part["value"].replace('{{', '').replace('}}', '')) + for part in generate_routes: + if part.type == "var": + part = cast(VarGenerateRouteChunk, part) + value_selector = part.value_selector + value = variable_pool.get_variable_value( + variable_selector=value_selector, + target_value_type=ValueType.STRING + ) + answer_part = { "type": "text", "text": value } # TODO File else: + part = cast(TextGenerateRouteChunk, part) answer_part = { "type": "text", - "text": part["value"] + "text": part.text } if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": @@ -75,6 +60,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: if len(answer) == 1 and answer[0]["type"] == "text": answer = answer[0]["text"] + # re-fetch variable values + variable_values = {} + for variable_selector in node_data.variables: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector, + target_value_type=ValueType.STRING + ) + + variable_values[variable_selector.variable] = value + return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variable_values, @@ -83,7 +78,61 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: } ) - def _is_variable(self, part, variable_keys): + @classmethod + def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + node_data = cast(cls._node_data_cls, node_data) + + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(node_data.answer) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + generate_routes = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) + + return generate_routes + + @classmethod + def _is_variable(cls, part, variable_keys): cleaned_part = part.replace('{{', '').replace('}}', '') return part.startswith('{{') and cleaned_part in variable_keys diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 7c6fed3e4ea6f4..8aed752ccb55e6 100644 --- 
a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,3 +1,6 @@ + +from pydantic import BaseModel + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData): """ variables: list[VariableSelector] = [] answer: str + + +class GenerateRouteChunk(BaseModel): + """ + Generate Route Chunk. + """ + type: str + + +class VarGenerateRouteChunk(GenerateRouteChunk): + """ + Var Generate Route Chunk. + """ + type: str = "var" + value_selector: list[str] + + +class TextGenerateRouteChunk(GenerateRouteChunk): + """ + Text Generate Route Chunk. + """ + type: str = "text" + text: str diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 2da19bc409d379..7cc9c6ee3dba81 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -86,17 +86,22 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult: self.node_run_result = result return result - def publish_text_chunk(self, text: str) -> None: + def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: """ Publish text chunk :param text: chunk text + :param value_selector: value selector :return: """ if self.callbacks: for callback in self.callbacks: callback.on_node_text_chunk( node_id=self.node_id, - text=text + text=text, + metadata={ + "node_type": self.node_type, + "value_selector": value_selector + } ) @classmethod diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 9285bbe74e87a1..cb5a33309141dd 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -169,7 +169,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage text = result.delta.message.content full_text += text - self.publish_text_chunk(text=text) + self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) if not model: model = result.model From 785dfc5c0085cc93e64480d5f03b7eaf00c1b57c Mon Sep 17 00:00:00 2001 From: jyong Date: Fri, 15 Mar 2024 14:40:53 +0800 Subject: [PATCH 159/160] dataset retrieval --- .../dataset_multi_retriever_tool.py | 194 ++++++++++ .../dataset_retriever_tool.py | 159 ++++++++ .../nodes/knowledge_retrieval/entities.py | 52 +++ .../knowledge_retrieval.py | 0 .../knowledge_retrieval_node.py | 364 +++++++++++++++++- 5 files changed, 766 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/entities.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py new file mode 100644 index 00000000000000..d9934acff9c619 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py @@ -0,0 +1,194 @@ +import threading +from typing import Optional + +from flask import Flask, current_app +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelManager +from
core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(BaseTool): + """Tool for querying multi dataset.""" + name: str = "dataset-" + args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput + description: str = "dataset multi retriever and rerank. " + tenant_id: str + dataset_ids: list[str] + top_k: int = 2 + score_threshold: Optional[float] = None + reranking_provider_name: str + reranking_model_name: str + return_resource: bool + retriever_from: str + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f'dataset-{tenant_id}', + tenant_id=tenant_id, + dataset_ids=dataset_ids, + **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'all_documents': all_documents, + 'hit_callbacks': self.hit_callbacks + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == 
segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': document_score_list.get(segment.index_node_id, None) + } + + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler]): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + for hit_callback in hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + if documents: + all_documents.extend(documents) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + + all_documents.extend(documents) \ No newline at end of file diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py new file mode 100644 index 00000000000000..13331d981bbecf --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py @@ -0,0 +1,159 @@ +from typing import Optional + +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.datasource.retrieval_service import RetrievalService +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetRetrieverToolInput(BaseModel): + query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") + 
+ +class DatasetRetrieverTool(BaseTool): + """Tool for querying a Dataset.""" + name: str = "dataset" + args_schema: type[BaseModel] = DatasetRetrieverToolInput + description: str = "use this to retrieve a dataset. " + + tenant_id: str + dataset_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + @classmethod + def from_dataset(cls, dataset: Dataset, **kwargs): + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + return cls( + name=f'dataset-{dataset.id}', + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + description=description, + **kwargs + ) + + def _run(self, query: str) -> str: + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == self.dataset_id + ).first() + + if not dataset: + return '' + + for hit_callback in self.hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 
'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': document_score_list.get(segment.index_node_id, None) + + } + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py new file mode 100644 index 00000000000000..905ee1f80da0e3 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -0,0 +1,52 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + provider: str + mode: str + + +class MultipleRetrievalConfig(BaseModel): + """ + Multiple Retrieval Config. + """ + top_k: int + score_threshold: Optional[float] + reranking_model: RerankingModelConfig + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class SingleRetrievalConfig(BaseModel): + """ + Single Retrieval Config. + """ + model: ModelConfig + + +class KnowledgeRetrievalNodeData(BaseNodeData): + """ + Knowledge retrieval Node Data. 
+ """ + variables: list[VariableSelector] + dataset_ids: list[str] + retrieval_mode: Literal['single', 'multiple'] + multiple_retrieval_config: MultipleRetrievalConfig + singleRetrievalConfig: SingleRetrievalConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7b8344418b2fa0..1ccdbf971c1ce4 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,13 +1,371 @@ +import threading +from typing import cast, Any + +from flask import current_app, Flask + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, Document +from models.workflow import WorkflowNodeExecutionStatus +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} class KnowledgeRetrievalNode(BaseNode): + + _node_data_cls = KnowledgeRetrievalNodeData + _node_type = NodeType.TOOL + def _run(self, variable_pool: VariablePool) -> NodeRunResult: - pass + node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # retrieve knowledge + try: + outputs = self._fetch_dataset_retriever( + node_data=node_data, variables=variables + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs + ) + + except Exception as e: + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + """ + A dataset tool is a tool that can be used to retrieve information from a 
dataset + :param node_data: node data + :param variables: variables + """ + tools = [] + available_datasets = [] + dataset_ids = node_data.dataset_ids + for dataset_id in dataset_ids: + # get dataset from dataset id + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + # pass if dataset is not available + if not dataset: + continue + + # pass if dataset is not available + if (dataset and dataset.available_document_count == 0 + and dataset.available_document_count == 0): + continue + + available_datasets.append(dataset) + all_documents = [] + if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + all_documents = self._single_retrieve(available_datasets, node_data, variables) + elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + all_documents = self._multiple_retrieve(available_datasets, node_data, variables) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + context_list = [] + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + + source = { + 'metadata': { + '_source': 'knowledge', + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'document_data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': 'workflow', + 'score': document_score_list.get(segment.index_node_id, None), + 'segment_hit_count': segment.hit_count, + 'segment_word_count': segment.word_count, + 'segment_position': segment.position + } + } + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + + return context_list @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - pass + node_data = node_data + node_data = cast(cls._node_data_cls, node_data) + return { + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } + + def _single_retrieve(self, available_datasets, node_data, variables): + tools = [] + for dataset in available_datasets: + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + 
description = description.replace('\n', '').replace('\r', '') + message_tool = PromptMessageTool( + name=dataset.id, + description=description, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + tools.append(message_tool) + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + prompt_messages = [ + SystemPromptMessage(content='You are a helpful AI assistant.'), + UserPromptMessage(content=variables['#query#']) + ] + result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=tools, + stream=False, + model_parameters={ + 'temperature': 0.2, + 'top_p': 0.3, + 'max_tokens': 1500 + } + ) + + if result.message.tool_calls: + # get retrieval model config + function_call_name = result.message.tool_calls[0].function.name + dataset = db.session.query(Dataset).filter( + Dataset.id == function_call_name + ).first() + if dataset: + retrieval_model_config = dataset.retrieval_model \ + if dataset.retrieval_model else default_retrieval_model + + # get top k + top_k = retrieval_model_config['top_k'] + # get retrieval method + retrival_method = retrieval_model_config['search_method'] + # get reranking model + reranking_model = retrieval_model_config['reranking_model'] + # get score threshold + score_threshold = .0 + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: + score_threshold = retrieval_model_config.get("score_threshold") + + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'], + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + return results + + + + def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.singleRetrievalConfig.model.name + provider_name = node_data.singleRetrievalConfig.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.singleRetrievalConfig.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.singleRetrievalConfig.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = 
model_type_instance.get_model_schema( + model_name, + model_credentials + ) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _multiple_retrieve(self, available_datasets, node_data, variables): + threads = [] + all_documents = [] + dataset_ids = [dataset.id for dataset in available_datasets] + for dataset in available_datasets: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset.id, + 'query': variables['#query#'], + 'top_k': node_data.multiple_retrieval_config.top_k, + 'all_documents': all_documents, + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=node_data.multiple_retrieval_config.reranking_model.provider, + model_type=ModelType.RERANK, + model=node_data.multiple_retrieval_config.reranking_model.name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(variables['#query#'], all_documents, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.top_k) + + return all_documents + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None ) + + all_documents.extend(documents) \ No newline at end of file From 9b57b4c6c8591473a4b6f88b3e6d58f0ab5ace53 Mon Sep 17 00:00:00 2001 From: jyong Date: Fri, 15 Mar 2024 16:14:32 +0800 Subject: [PATCH 160/160] dataset retrieval --- .../knowledge_retrieval_node.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1ccdbf971c1ce4..a501113dc313b8 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -33,10 +33,10 @@ 'score_threshold_enabled': False } -class
KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData - _node_type = NodeType.TOOL + _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) @@ -67,7 +67,9 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: inputs=variables, error=str(e) ) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[ + dict[str, Any]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param node_data: node data @@ -224,14 +226,14 @@ def _single_retrieve(self, available_datasets, node_data, variables): if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'], - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, + query=variables['#query#'], + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) return results - - - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ + ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -333,7 +335,7 @@ def _multiple_retrieve(self, available_datasets, node_data, variables): return all_documents - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): dataset = db.session.query(Dataset).filter( Dataset.tenant_id == self.tenant_id, @@ -368,4 +370,4 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, if retrieval_model['reranking_enable'] else None ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents)
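
A note on the ToolNodeData change in PATCH 156: the two new checks are pydantic v1 cross-field validators, and in that style the values argument only carries fields declared before the field being validated, so variable_type must precede value_selector and value for the checks to see it. A self-contained reduction of the rule (ToolInputSketch is an illustrative name, not a Dify class):

from typing import Literal, Optional

from pydantic import BaseModel, ValidationError, validator


class ToolInputSketch(BaseModel):
    variable: str
    # declared before the validated fields so it shows up in `values`
    variable_type: Literal['selector', 'static']
    value_selector: Optional[list[str]] = None
    value: Optional[str] = None

    @validator('value_selector')
    def check_value_selector(cls, value_selector, values, **kwargs):
        if values.get('variable_type') == 'selector' and value_selector is None:
            raise ValueError('value_selector is required for selector variable')
        return value_selector

    @validator('value')
    def check_value(cls, value, values, **kwargs):
        if values.get('variable_type') == 'static' and value is None:
            raise ValueError('value is required for static variable')
        return value


ToolInputSketch(variable='q', variable_type='selector', value_selector=['start', 'query'])
try:
    ToolInputSketch(variable='q', variable_type='static')  # no value: rejected
except ValidationError:
    pass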
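
The heart of the answer-stream support in PATCH 158 is that AnswerNode now splits the answer template once, ahead of execution, into static text chunks and variable chunks (the generate route), so the pipeline can emit the literal parts of an answer before the nodes it references have run. A minimal sketch of that extraction; it uses re.split where the patch uses PromptTemplateParser plus the Ω sentinel, and the selectors in the usage lines are illustrative:

import re
from typing import Union

from pydantic import BaseModel


class TextChunk(BaseModel):
    type: str = 'text'
    text: str


class VarChunk(BaseModel):
    type: str = 'var'
    value_selector: list[str]


def extract_generate_route(answer: str, selectors: dict[str, list[str]]) -> list[Union[TextChunk, VarChunk]]:
    # split on {{var}} placeholders; the capture group keeps the delimiters
    route: list[Union[TextChunk, VarChunk]] = []
    for part in re.split(r'(\{\{[a-zA-Z0-9_]+\}\})', answer):
        if not part:
            continue
        key = part.removeprefix('{{').removesuffix('}}')
        if part.startswith('{{') and key in selectors:
            route.append(VarChunk(value_selector=selectors[key]))
        else:
            route.append(TextChunk(text=part))
    return route


# text chunks can be streamed as soon as the route's start node is reached;
# var chunks wait until the node they reference has produced its output
route = extract_generate_route('Hi {{name}}, result: {{result}}',
                               {'name': ['start', 'name'], 'result': ['llm', 'text']})
assert [chunk.type for chunk in route] == ['text', 'var', 'text', 'var']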
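
Deciding where each route may start streaming is the job of _get_answer_start_at_node_id, which walks the graph backwards from the Answer node until it reaches the Start node or a branching node whose outcome is unknown at stream time. A standalone version of that walk, assuming node types arrive as plain strings in graph_dict and, like the patch, following only the first incoming edge (the patch's own membership test mixes .value strings with the bare NodeType.QUESTION_CLASSIFIER member, which looks unintended if NodeType is a plain Enum):

from typing import Optional

# node types whose outcome must be known before streaming can begin
BRANCH_NODE_TYPES = {'answer', 'if-else', 'question-classifier'}


def find_stream_start_node(graph: dict, target_node_id: str) -> Optional[str]:
    nodes = {node['id']: node for node in graph.get('nodes', [])}
    ingoing = next((edge for edge in graph.get('edges', [])
                    if edge.get('target') == target_node_id), None)
    if not ingoing:
        return None

    source = nodes.get(ingoing['source'])
    if not source:
        return None

    node_type = source.get('data', {}).get('type')
    if node_type in BRANCH_NODE_TYPES:
        # a branch upstream: streaming can only start at the answer itself
        return target_node_id
    if node_type == 'start':
        return ingoing['source']
    # otherwise keep walking towards the start of the graph
    return find_stream_start_node(graph, ingoing['source'])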
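
In PATCH 159, single-mode retrieval routes the query by letting the model choose the dataset: every available dataset is surfaced to the LLM as a function-call tool whose name is the dataset id and whose description is the dataset description (with the fallback string the patch uses), and whichever tool call comes back names the dataset to search. A sketch of just the tool construction, with DatasetStub and DatasetToolStub as stand-ins for the Dataset model and PromptMessageTool:

from dataclasses import dataclass, field


@dataclass
class DatasetStub:
    id: str
    name: str
    description: str = ''


@dataclass
class DatasetToolStub:
    name: str
    description: str
    parameters: dict = field(default_factory=dict)


def build_dataset_tools(datasets: list[DatasetStub]) -> list[DatasetToolStub]:
    tools = []
    for dataset in datasets:
        # same fallback wording as the patch when a dataset has no description
        description = dataset.description or (
            'useful for when you want to answer queries about the ' + dataset.name)
        tools.append(DatasetToolStub(
            name=dataset.id,  # the tool name doubles as the routing key
            description=description.replace('\n', '').replace('\r', ''),
            parameters={'type': 'object', 'properties': {}, 'required': []},
        ))
    return tools


tools = build_dataset_tools([DatasetStub('d1', 'FAQ'),
                             DatasetStub('d2', 'Docs', 'product documentation')])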
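
Multiple-mode retrieval, and DatasetMultiRetrieverTool with it, follows a fan-out-then-rerank shape: one thread per dataset appends candidates into a shared list, the threads are joined, and a single rerank pass over the combined pool keeps scores comparable before the final top_k cut. A reduced sketch in which retrieve_fn and rerank_fn stand in for RetrievalService.retrieve and RerankRunner.run; the Flask app context that the real _retriever pushes into each thread is omitted:

import threading
from collections.abc import Callable


def multiple_retrieve(dataset_ids: list[str], query: str, top_k: int,
                      retrieve_fn: Callable[[str, str, int], list],
                      rerank_fn: Callable[[str, list, int], list]) -> list:
    all_documents: list = []

    def worker(dataset_id: str) -> None:
        # extending a shared list is safe here under CPython's GIL
        all_documents.extend(retrieve_fn(dataset_id, query, top_k))

    threads = [threading.Thread(target=worker, args=(dataset_id,))
               for dataset_id in dataset_ids]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # rerank across datasets so per-dataset scores stay comparable
    return rerank_fn(query, all_documents, top_k)


docs = multiple_retrieve(['d1', 'd2'], 'pricing', 2,
                         retrieve_fn=lambda ds, q, k: [f'{ds}: {q}'],
                         rerank_fn=lambda q, pool, k: pool[:k])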