diff --git a/.github/workflows/tool-tests.yaml b/.github/workflows/api-tools-tests.yaml similarity index 100% rename from .github/workflows/tool-tests.yaml rename to .github/workflows/api-tools-tests.yaml diff --git a/.github/workflows/api-workflow-tests.yaml b/.github/workflows/api-workflow-tests.yaml new file mode 100644 index 00000000000000..37a138b44dc1e0 --- /dev/null +++ b/.github/workflows/api-workflow-tests.yaml @@ -0,0 +1,31 @@ +name: Run Pytest + +on: + pull_request: + branches: + - main + - deploy/dev + +jobs: + test: + runs-on: ubuntu-latest + + env: + MOCK_SWITCH: true + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: ./api/requirements.txt + + - name: Install dependencies + run: pip install -r ./api/requirements.txt + + - name: Run pytest + run: pytest api/tests/integration_tests/workflow \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index 32d89d4287c599..832c7e3bab6c58 100644 --- a/api/.env.example +++ b/api/.env.example @@ -132,3 +132,7 @@ SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= BATCH_UPLOAD_LIMIT=10 + +# CODE EXECUTION CONFIGURATION +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox diff --git a/api/commands.py b/api/commands.py index 250039a3650c2b..376a394d1e18e6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, MessageAnnotation +from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel @@ -370,8 +370,62 @@ def migrate_knowledge_vector_database(): fg='green')) +@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +def convert_to_agent_apps(): + """ + Convert Agent Assistant to Agent App. + """ + click.echo(click.style('Start convert to agent apps.', fg='green')) + + proceeded_app_ids = [] + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id AS id FROM apps a + INNER JOIN app_model_configs am ON a.app_model_config_id=am.id + WHERE a.mode = 'chat' AND am.agent_mode is not null + and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') + and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.query(App).filter(App.id == app_id).first() + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo('Converting app: {}'.format(app.id)) + + try: + app.mode = AppMode.AGENT_CHAT.value + db.session.commit() + + # update conversation mode to agent + db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + {Conversation.mode: AppMode.AGENT_CHAT.value} + ) + + db.session.commit() + click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + except Exception as e: + click.echo( + click.style('Convert app error: {} {}'.format(e.__class__.__name__, + str(e)), fg='red')) + + click.echo(click.style('Congratulations! 
Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) + app.cli.add_command(convert_to_agent_apps) diff --git a/api/config.py b/api/config.py index a978a099b9ed4d..a4ec6fcef95731 100644 --- a/api/config.py +++ b/api/config.py @@ -27,6 +27,7 @@ 'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'DEPLOY_ENV': 'PRODUCTION', 'SQLALCHEMY_POOL_SIZE': 30, + 'SQLALCHEMY_MAX_OVERFLOW': 10, 'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_ECHO': 'False', 'SENTRY_TRACES_SAMPLE_RATE': 1.0, @@ -59,7 +60,9 @@ 'CAN_REPLACE_LOGO': 'False', 'ETL_TYPE': 'dify', 'KEYWORD_STORE': 'jieba', - 'BATCH_UPLOAD_LIMIT': 20 + 'BATCH_UPLOAD_LIMIT': 20, + 'CODE_EXECUTION_ENDPOINT': '', + 'CODE_EXECUTION_API_KEY': '' } @@ -146,6 +149,7 @@ def __init__(self): self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" self.SQLALCHEMY_ENGINE_OPTIONS = { 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), + 'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')), 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) } @@ -293,6 +297,9 @@ def __init__(self): self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') + self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') + self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') + self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED') diff --git a/api/constants/languages.py b/api/constants/languages.py index 0ae69d77d20283..dd8a29eaef3944 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,6 +1,4 @@ -import json -from models.model import AppModelConfig languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN'] @@ -91,512 +89,3 @@ def supported_language(lang): } ], } - -demo_model_templates = { - 'en-US': [ - { - 'name': 'Translation Assistant', - 'icon': '', - 'icon_background': '', - 'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Please translate the following text into {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Target Language", - "description": "The language you want to translate into.", - "type": "select", - "default": "Chinese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Target Language", - "variable": 
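Note on the pool settings: config.py now passes a `max_overflow` value into SQLALCHEMY_ENGINE_OPTIONS alongside the existing pool options. A minimal standalone sketch of what these options mean in plain SQLAlchemy, using the new defaults (pool_size 30, max_overflow 10, pool_recycle 3600); the DSN is a placeholder.

```python
from sqlalchemy import create_engine

# Standalone illustration of the engine options config.py now assembles;
# values mirror the defaults added in this diff.
engine = create_engine(
    'postgresql://user:password@localhost:5432/dify',  # placeholder DSN
    pool_size=30,      # steady-state connections kept open in the pool
    pool_recycle=3600, # recycle connections older than an hour
    max_overflow=10,   # extra short-lived connections allowed under burst,
                       # so at most 40 concurrent connections in total
)
```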
"target_language", - "description": "The language you want to translate into.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI Front-end Interviewer', - 'icon': '', - 'icon_background': '', - 'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - 'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - suggested_questions=None, - pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? 
After the user answers this question, please express your understanding and support.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'zh-Hans': [ - { - 'name': '翻译助手', - 'icon': '', - 'icon_background': '', - 'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "请将以下文本翻译为{{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "目标语言", - "description": "翻译的目标语言", - "type": "select", - "default": "中文", - "options": [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "目标语言", - "variable": "target_language", - "description": "翻译的目标语言", - "default": "中文", - "required": True, - 'options': [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - }, { - "paragraph": { - "label": "文本内容", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI 前端面试官', - 'icon': '', - 'icon_background': '', - 'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - 'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - suggested_questions=None, - pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'uk-UA': [ - { - "name": "Помічник перекладу", - "icon": "", - "icon_background": "", - "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", - "mode": "completion", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo-instruct", - configs={ - "prompt_template": "Будь ласка, 
перекладіть наступний текст на {{target_language}}:\n", - "prompt_variables": [ - { - "key": "target_language", - "name": "Цільова мова", - "description": "Мова, на яку ви хочете перекласти.", - "type": "select", - "default": "Ukrainian", - "options": [ - "Chinese", - "English", - "Japanese", - "French", - "Russian", - "German", - "Spanish", - "Korean", - "Italian", - ], - }, - ], - "completion_params": { - "max_token": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="", - suggested_questions=None, - pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Цільова мова", - "variable": "target_language", - "description": "Мова, на яку ви хочете перекласти.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Запит", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - "name": "AI інтерв’юер фронтенду", - "icon": "", - "icon_background": "", - "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.", - "mode": "chat", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo", - configs={ - "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ", - "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - "prompt_variables": [], - "completion_params": { - "max_token": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. 
", - suggested_questions=None, - pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=None - ), - } - ], - 'vi-VN': [ - { - 'name': 'Trợ lý dịch thuật', - 'icon': '', - 'icon_background': '', - 'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Ngôn ngữ đích", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "type": "select", - "default": "Vietnamese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Ngôn ngữ đích", - "variable": "target_language", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "default": "Vietnamese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'Phỏng vấn front-end AI', - 'icon': '', - 'icon_background': '', - 'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. 
', - 'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ', - suggested_questions=None, - pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], -} diff --git a/api/constants/model_template.py b/api/constants/model_template.py index d87f7c392610f7..c8aaba23cb83ed 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,64 +1,55 @@ -import json +from models.model import AppMode -model_templates = { - # completion default mode - 'completion_default': { +default_app_templates = { + # workflow default mode + AppMode.WORKFLOW: { 'app': { - 'mode': 'completion', + 'mode': AppMode.WORKFLOW.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True + } + }, + + # chat default mode + AppMode.CHAT: { + 'app': { + 'mode': AppMode.CHAT.value, + 'enable_site': True, + 'enable_api': True }, 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {}, - 'model': json.dumps({ + 'model': { "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", + "name": "gpt-4", + "mode": "chat", "completion_params": {} - }), - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + } } }, - # chat default mode - 'chat_default': { + # advanced-chat default mode + AppMode.ADVANCED_CHAT: { 'app': { - 'mode': 'chat', + 'mode': AppMode.ADVANCED_CHAT.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 
'api_rph': 0, - 'status': 'normal' + 'enable_api': True + } + }, + + # agent-chat default mode + AppMode.AGENT_CHAT: { + 'app': { + 'mode': AppMode.AGENT_CHAT.value, + 'enable_site': True, + 'enable_api': True }, 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {}, - 'model': json.dumps({ + 'model': { "provider": "openai", - "name": "gpt-3.5-turbo", + "name": "gpt-4", "mode": "chat", "completion_params": {} - }) + } } - }, + } } - - diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ecfdc38612ce64..853ca9e3a7ca4d 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -5,10 +5,10 @@ api = ExternalApi(bp) # Import other controllers -from . import admin, apikey, extension, feature, setup, version +from . import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic) + model_config, site, statistic, workflow, workflow_run, workflow_app_log) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py index b0b07517f10aad..e69de29bb2d1d6 100644 --- a/api/controllers/console/app/__init__.py +++ b/api/controllers/console/app/__init__.py @@ -1,21 +0,0 @@ -from controllers.console.app.error import AppUnavailableError -from extensions.ext_database import db -from flask_login import current_user -from models.model import App -from werkzeug.exceptions import NotFound - - -def _get_app(app_id, mode=None): - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() - - if not app: - raise NotFound("App not found") - - if mode and app.mode != mode: - raise NotFound("The {} app not found".format(mode)) - - return app diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ff974054155f22..94406030697458 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,41 +1,28 @@ import json -import logging -from datetime import datetime from flask_login import current_user -from flask_restful import Resource, abort, inputs, marshal_with, reqparse -from werkzeug.exceptions import Forbidden +from flask_restful import Resource, inputs, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, BadRequest -from constants.languages import demo_model_templates, languages -from constants.model_template import model_templates from controllers.console import api -from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.provider_manager import ProviderManager -from events.app_event import app_was_created, app_was_deleted +from core.agent.entities import AgentToolEntity from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, app_pagination_fields, - 
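Note on the templates: default_app_templates is now keyed by AppMode, and the legacy per-app fields (is_demo, api_rpm, api_rph, status) are gone from the templates, presumably left to column defaults. The actual consumer is AppService.create_app, which this diff references but does not include; the sketch below shows one plausible way the table gets used and is an assumption, not the real implementation.

```python
from constants.model_template import default_app_templates
from models.model import App, AppMode


def build_app_from_template(name: str, mode: str, tenant_id: str) -> App:
    """Hypothetical sketch of the lookup AppService.create_app presumably
    performs; the real service is not part of this diff."""
    template = default_app_templates[AppMode.value_of(mode)]

    app = App(**template['app'])  # mode, enable_site, enable_api
    app.name = name
    app.tenant_id = tenant_id

    # The chat and agent-chat templates also carry a default 'model_config'
    # dict (gpt-4) that would be copied into an AppModelConfig row here.
    return app
```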
template_list_fields, ) from libs.login import login_required -from models.model import App, AppModelConfig, Site -from services.app_model_config_service import AppModelConfigService +from services.app_service import AppService +from models.model import App, AppModelConfig, AppMode from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager -from core.entities.application_entities import AgentToolEntity -def _get_app(app_id, tenant_id): - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() - if not app: - raise AppNotFoundError - return app + +ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] class AppListApi(Resource): @@ -49,33 +36,15 @@ def get(self): parser = reqparse.RequestParser() parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False) + parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() - filters = [ - App.tenant_id == current_user.current_tenant_id, - App.is_universal == False - ] - - if args['mode'] == 'completion': - filters.append(App.mode == 'completion') - elif args['mode'] == 'chat': - filters.append(App.mode == 'chat') - else: - pass - - if 'name' in args and args['name']: - filters.append(App.name.ilike(f'%{args["name"]}%')) + # get app list + app_service = AppService() + app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) - app_models = db.paginate( - db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False - ) - - return app_models + return app_pagination @setup_required @login_required @@ -86,147 +55,49 @@ def post(self): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('model_config', type=dict, location='json') args = parser.parse_args() # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - try: - provider_manager = ProviderManager() - default_model_entity = provider_manager.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except (ProviderTokenNotInitError, LLMBadRequestError): - default_model_entity = None - except Exception as e: - logging.exception(e) - default_model_entity = None - - if args['model_config'] is not None: - # validate config - model_config_dict = args['model_config'] - - # Get provider configurations - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id) - - # get available models from 
provider_configurations - available_models = provider_configurations.get_models( - model_type=ModelType.LLM, - only_active=True - ) - - # check if model is available - available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models] - provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}" - if provider_model not in available_models_names: - if not default_model_entity: - raise ProviderNotInitializeError( - "No Default System Reasoning Model available. Please configure " - "in the Settings -> Model Provider.") - else: - model_config_dict["model"]["provider"] = default_model_entity.provider.provider - model_config_dict["model"]["name"] = default_model_entity.model - - model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - account=current_user, - config=model_config_dict, - app_mode=args['mode'] - ) - - app = App( - enable_site=True, - enable_api=True, - is_demo=False, - api_rpm=0, - api_rph=0, - status='normal' - ) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_configuration) - else: - if 'mode' not in args or args['mode'] is None: - abort(400, message="mode is required") - - model_config_template = model_templates[args['mode'] + '_default'] - - app = App(**model_config_template['app']) - app_model_config = AppModelConfig(**model_config_template['model_config']) - - # get model provider - model_manager = ModelManager() - - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) - - app.name = args['name'] - app.mode = args['mode'] - app.icon = args['icon'] - app.icon_background = args['icon_background'] - app.tenant_id = current_user.current_tenant_id - - db.session.add(app) - db.session.flush() - - app_model_config.app_id = app.id - db.session.add(app_model_config) - db.session.flush() - - app.app_model_config_id = app_model_config.id - - account = current_user - - site = Site( - app_id=app.id, - title=app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(app) + if 'mode' not in args or args['mode'] is None: + raise BadRequest("mode is required") + + app_service = AppService() + app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 - -class AppTemplateApi(Resource): +class AppImportApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(template_list_fields) - def get(self): - """Get app demo templates""" - account = current_user - interface_language = account.interface_language + @marshal_with(app_detail_fields_with_site) + @cloud_edition_billing_resource_check('apps') + def post(self): + """Import app""" + # The role of the current user in the ta table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() - templates = demo_model_templates.get(interface_language) - if not templates: - templates = demo_model_templates.get(languages[0]) + parser = reqparse.RequestParser() + 
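Note on the new create flow: with the provider checks and inline model_config handling removed, creating an app reduces to a name plus one of ALLOW_CREATE_APP_MODES. A hedged client-side example; the console URL prefix and bearer-token auth are placeholders, not necessarily how console sessions actually authenticate.

```python
import requests  # illustrative client only

payload = {
    'name': 'My workflow app',
    'description': 'demo',
    'mode': 'workflow',  # must be one of ALLOW_CREATE_APP_MODES
    'icon': '🤖',
    'icon_background': '#FFEAD5',
}

# 'mode' is now required up front; the old path that inferred a default
# template from a posted model_config is gone.
resp = requests.post(
    'http://localhost:5001/console/api/apps',     # assumed console API prefix
    json=payload,
    headers={'Authorization': 'Bearer <token>'},  # placeholder auth
)
print(resp.status_code)  # 201 on success per the new handler
```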
parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument('name', type=str, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() - return {'data': templates} + app_service = AppService() + app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user) + + return app, 201 class AppApi(Resource): @@ -234,213 +105,199 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields_with_site) - def get(self, app_id): + def get(self, app_model): """Get app detail""" - app_id = str(app_id) - app: App = _get_app(app_id, current_user.current_tenant_id) - # get original app model config - model_config: AppModelConfig = app.app_model_config - agent_mode = model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - masked_parameter = {} - - # override tool parameters - tool['tool_parameters'] = masked_parameter - except Exception as e: - pass - - # override agent mode - model_config.agent_mode = json.dumps(agent_mode) - - return app + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + model_config: AppModelConfig = app_model.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + agent_tool_entity = AgentToolEntity(**tool) + # get tool + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} + + # override tool parameters + tool['tool_parameters'] = masked_parameter + except Exception as e: + pass + + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) + db.session.commit() + + return app_model + + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_detail_fields_with_site) + def put(self, app_model): + """Update app""" + 
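Note on secret handling: the GET handler decrypts stored secret-input tool parameters and then masks them before they leave the API, so plaintext secrets never reach the console. The real work happens in ToolParameterConfigurationManager, which is outside this diff; the toy below only illustrates the masking step.

```python
def mask_secret(value: str, visible: int = 2) -> str:
    """Toy stand-in for the mask_tool_parameters step: keep a short prefix
    and blank out the rest."""
    return value[:visible] + '*' * max(len(value) - visible, 0)


params = {'api_key': 'sk-1234567890'}
masked = {k: mask_secret(v) for k, v in params.items()}
print(masked)  # {'api_key': 'sk***********'}
```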
parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + app_service = AppService() + app_model = app_service.update_app(app_model, args) + + return app_model @setup_required @login_required @account_initialization_required - def delete(self, app_id): + @get_app_model + def delete(self, app_model): """Delete app""" - app_id = str(app_id) + if not current_user.is_admin_or_owner: + raise Forbidden() + + app_service = AppService() + app_service.delete_app(app_model) + + return {'result': 'success'}, 204 + +class AppCopyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_detail_fields_with_site) + def post(self, app_model): + """Copy app""" + # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - app = _get_app(app_id, current_user.current_tenant_id) + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() - db.session.delete(app) - db.session.commit() + app_service = AppService() + data = app_service.export_app(app_model) + app = app_service.import_app(current_user.current_tenant_id, data, args, current_user) - # todo delete related data?? - # model_config, site, api_token, conversation, message, message_feedback, message_annotation + return app, 201 - app_was_deleted.send(app) - return {'result': 'success'}, 204 +class AppExportApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + """Export app""" + app_service = AppService() + + return { + "data": app_service.export_app(app_model) + } class AppNameApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app.name = args.get('name') - app.updated_at = datetime.utcnow() - db.session.commit() - return app + app_service = AppService() + app_model = app_service.update_app_name(app_model, args.get('name')) + + return app_model class AppIconApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') - app.updated_at = datetime.utcnow() - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) - return app + 
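Note on get_app_model: the decorator is imported from controllers/console/app/wraps.py, which is not part of this section, yet it is the pattern every rewritten handler relies on. It is used both bare (@get_app_model) and parameterized (@get_app_model(mode=...)), loads the tenant-scoped App, and injects it into the view as the app_model argument, replacing the deleted _get_app helper. The sketch below is an assumed reconstruction; _load_app is a placeholder for the database lookup.

```python
from functools import wraps
from typing import Optional, Union

from werkzeug.exceptions import NotFound

from models.model import App, AppMode


def get_app_model(view=None, *, mode: Optional[Union[AppMode, list[AppMode]]] = None):
    """Assumed shape of the decorator in controllers/console/app/wraps.py;
    the real implementation may differ."""

    def decorator(view_func):
        @wraps(view_func)
        def wrapper(*args, **kwargs):
            app_id = str(kwargs.pop('app_id'))
            app = _load_app(app_id)  # placeholder; real code queries db.session
            if mode is not None:
                expected = mode if isinstance(mode, list) else [mode]
                if AppMode.value_of(app.mode) not in expected:
                    raise NotFound(f"The {app.mode} app not found")
            return view_func(*args, app_model=app, **kwargs)

        return wrapper

    # supports both bare @get_app_model and @get_app_model(mode=...)
    return decorator(view) if view is not None else decorator


def _load_app(app_id: str) -> App:
    """Placeholder for the tenant-scoped lookup _get_app used to perform."""
    raise NotImplementedError
```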
return app_model class AppSiteStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_site', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id).first() - if not app: - raise AppNotFoundError - if args.get('enable_site') == app.enable_site: - return app + app_service = AppService() + app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) - app.enable_site = args.get('enable_site') - app.updated_at = datetime.utcnow() - db.session.commit() - return app + return app_model class AppApiStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - - if args.get('enable_api') == app.enable_api: - return app - - app.enable_api = args.get('enable_api') - app.updated_at = datetime.utcnow() - db.session.commit() - return app - - -class AppCopy(Resource): - @staticmethod - def create_app_copy(app): - copy_app = App( - name=app.name + ' copy', - icon=app.icon, - icon_background=app.icon_background, - tenant_id=app.tenant_id, - mode=app.mode, - app_model_config_id=app.app_model_config_id, - enable_site=app.enable_site, - enable_api=app.enable_api, - api_rpm=app.api_rpm, - api_rph=app.api_rph - ) - return copy_app - - @staticmethod - def create_app_model_config_copy(app_config, copy_app_id): - copy_app_model_config = app_config.copy() - copy_app_model_config.app_id = copy_app_id - - return copy_app_model_config - - @setup_required - @login_required - @account_initialization_required - @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - - copy_app = self.create_app_copy(app) - db.session.add(copy_app) - - app_config = db.session.query(AppModelConfig). \ - filter(AppModelConfig.app_id == app_id). 
\ - one_or_none() - - if app_config: - copy_app_model_config = self.create_app_model_config_copy(app_config, copy_app.id) - db.session.add(copy_app_model_config) - db.session.commit() - copy_app.app_model_config_id = copy_app_model_config.id - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) - return copy_app, 201 + return app_model api.add_resource(AppListApi, '/apps') -api.add_resource(AppTemplateApi, '/app-templates') +api.add_resource(AppImportApi, '/apps/import') api.add_resource(AppApi, '/apps/') -api.add_resource(AppCopy, '/apps//copy') +api.add_resource(AppCopyApi, '/apps//copy') +api.add_resource(AppExportApi, '/apps//export') api.add_resource(AppNameApi, '/apps//name') api.add_resource(AppIconApi, '/apps//icon') api.add_resource(AppSiteStatus, '/apps//site-enable') diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 77eaf136fc644f..29d89ae4603b75 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,7 +6,6 @@ import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -18,11 +17,13 @@ ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required +from models.model import AppMode from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -36,15 +37,13 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) + def post(self, app_model): file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=None, ) @@ -80,15 +79,13 @@ class ChatMessageTextApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, None) - + @get_app_model + def post(self, app_model): try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) @@ -120,9 +117,11 @@ def post(self, app_id): class TextModesApi(Resource): - def get(self, app_id: str): - app_model = _get_app(str(app_id)) - + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): try: parser = reqparse.RequestParser() parser.add_argument('language', type=str, required=True, location='args') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f01d2afa031699..a7fd0164d86f73 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -10,7 +10,6 
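Note on copy vs. export/import: the new copy endpoint is defined as export_app followed by import_app, and both halves are exposed as their own routes, so the same mechanism doubles as app backup and transfer. A client-side sketch of that round trip; host, prefix, auth, and the app id are placeholders.

```python
import requests  # illustrative client; host, prefix and auth are placeholders

BASE = 'http://localhost:5001/console/api'       # assumed console API prefix
HEADERS = {'Authorization': 'Bearer <token>'}    # placeholder auth
APP_ID = '00000000-0000-0000-0000-000000000000'  # placeholder app id

# GET /apps/<app_id>/export returns {"data": <serialized app>}
dsl = requests.get(f'{BASE}/apps/{APP_ID}/export', headers=HEADERS).json()['data']

# POST /apps/import recreates an app from that payload; AppCopyApi is
# literally export_app() followed by import_app() on the server side.
copy = requests.post(f'{BASE}/apps/import',
                     json={'data': dsl, 'name': 'Copied app'},
                     headers=HEADERS)
print(copy.status_code)  # 201 on success
```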
@@ import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,14 +18,16 @@ ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from libs.login import login_required +from models.model import AppMode from services.completion_service import CompletionService @@ -36,12 +37,8 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.COMPLETION) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') @@ -62,8 +59,7 @@ def post(self, app_id): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) @@ -93,15 +89,11 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.COMPLETION) + def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 @@ -110,12 +102,8 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') @@ -137,8 +125,7 @@ def post(self, app_id): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) @@ -179,15 +166,11 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'chat') - + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 diff --git 
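Note on the stop endpoints: both CompletionMessageStopApi and ChatMessageStopApi now delegate to the relocated AppQueueManager (previously ApplicationQueueManager). The import paths and call signature below come straight from the diff; what set_stop_flag does internally, presumably raising a flag the streaming generation loop polls, is not shown here and is an assumption.

```python
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom


def stop_debugger_task(task_id: str, account_id: str) -> None:
    """Mirror of the call the stop handlers make for debugger sessions."""
    # Assumed semantics: mark the task so generation exits early.
    AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account_id)
```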
a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 452b0fddf6c014..11dece3a9e58e5 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db @@ -21,7 +21,7 @@ ) from libs.helper import datetime_string from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation +from models.model import AppMode, Conversation, Message, MessageAnnotation class CompletionConversationApi(Resource): @@ -29,10 +29,9 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -43,10 +42,7 @@ def get(self, app_id): parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'completion') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') if args['keyword']: query = query.join( @@ -106,24 +102,22 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - return _get_conversation(app_id, conversation_id, 'completion') + return _get_conversation(app_model, conversation_id) @setup_required @login_required @account_initialization_required - def delete(self, app_id, conversation_id): - app_id = str(app_id) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) + def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) - app = _get_app(app_id, 'chat') - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") @@ -139,10 +133,9 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -154,10 +147,7 @@ def get(self, app_id): parser.add_argument('limit', type=int_range(1, 100), 
required=False, default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'chat') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'chat') if args['keyword']: query = query.join( @@ -228,25 +218,22 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_detail_fields) - def get(self, app_id, conversation_id): - app_id = str(app_id) + def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - return _get_conversation(app_id, conversation_id, 'chat') + return _get_conversation(app_model, conversation_id) @setup_required @login_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required - def delete(self, app_id, conversation_id): - app_id = str(app_id) + def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) - # get app info - app = _get_app(app_id, 'chat') - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") @@ -263,12 +250,9 @@ def delete(self, app_id, conversation_id): api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') -def _get_conversation(app_id, conversation_id, mode): - # get app info - app = _get_app(app_id, mode) - +def _get_conversation(app_model, conversation_id): conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index d7b31906c8de21..b1abb38248f5a2 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException): error_code = 'too_many_files' description = "Only one file is allowed." code = 400 + + +class DraftWorkflowNotExist(BaseHTTPException): + error_code = 'draft_workflow_not_exist' + description = "Draft workflow need to be initialized." 
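Note on the new error type: DraftWorkflowNotExist is added for the new workflow controllers, which live outside this section. An assumed usage sketch follows; fetch_draft_workflow is a hypothetical placeholder for whatever query the workflow controller actually runs.

```python
from controllers.console.app.error import DraftWorkflowNotExist


def fetch_draft_workflow(app_model):
    """Hypothetical lookup; the real code would query the workflows table
    for the app's draft row."""
    return None


def get_draft_workflow_or_raise(app_model):
    workflow = fetch_draft_workflow(app_model)
    if workflow is None:
        # surfaces as a 400 with error_code 'draft_workflow_not_exist'
        raise DraftWorkflowNotExist()
    return workflow
```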
+ code = 400 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 3ec932b5f11ead..ee02fc18465c5d 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -11,7 +11,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 0064dbe663b90f..56d2e718e7d4d1 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -10,17 +10,15 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -28,10 +26,8 @@ from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService @@ -46,14 +42,10 @@ class ChatMessageListApi(Resource): @setup_required @login_required + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id, 'chat') - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args') @@ -62,7 +54,7 @@ def get(self, app_id): conversation = db.session.query(Conversation).filter( Conversation.id == args['conversation_id'], - Conversation.app_id == app.id + Conversation.app_id == app_model.id ).first() if not conversation: @@ -110,12 +102,8 @@ class MessageFeedbackApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def post(self, 
app_model): parser = reqparse.RequestParser() parser.add_argument('message_id', required=True, type=uuid_value, location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') @@ -125,7 +113,7 @@ def post(self, app_id): message = db.session.query(Message).filter( Message.id == message_id, - Message.app_id == app.id + Message.app_id == app_model.id ).first() if not message: @@ -141,7 +129,7 @@ def post(self, app_id): raise ValueError('rating cannot be None when feedback not exists') else: feedback = MessageFeedback( - app_id=app.id, + app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=args['rating'], @@ -160,21 +148,20 @@ class MessageAnnotationApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check('annotation') + @get_app_model @marshal_with(annotation_fields) - def post(self, app_id): + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - app_id = str(app_id) - parser = reqparse.RequestParser() parser.add_argument('message_id', required=False, type=uuid_value, location='json') parser.add_argument('question', required=True, type=str, location='json') parser.add_argument('answer', required=True, type=str, location='json') parser.add_argument('annotation_reply', required=False, type=dict, location='json') args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) return annotation @@ -183,65 +170,15 @@ class MessageAnnotationCountApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def get(self, app_model): count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app.id + MessageAnnotation.app_id == app_model.id ).count() return {'count': count} -class MessageMoreLikeThisApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], - location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - # get app info - app_model = _get_app(app_id, 'completion') - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if 
isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -257,13 +194,10 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id, 'chat') - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, @@ -294,14 +228,11 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(message_detail_fields) - def get(self, app_id, message_id): - app_id = str(app_id) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id) - message = db.session.query(Message).filter( Message.id == message_id, Message.app_id == app_model.id @@ -313,7 +244,6 @@ def get(self, app_id, message_id): return message -api.add_resource(MessageMoreLikeThisApi, '/apps/<uuid:app_id>/completion-messages/<uuid:message_id>/more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions') api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages') api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks') diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 2095bb6bea4c2f..41b7151ba65013 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -5,16 +5,16 @@ from flask_restful import Resource from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AgentToolEntity +from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required -from models.model import AppModelConfig +from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService @@ -23,118 +23,115 @@ class ModelConfigResource(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) + def post(self, app_model): """Modify app model config""" - app_id = str(app_id) - - app = _get_app(app_id) - # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, - account=current_user, config=request.json, - app_mode=app.mode + app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( - app_id=app.id, + app_id=app_model.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app.app_model_config_id - ).first() - agent_mode = original_app_model_config.agent_mode_dict - # decrypt agent 
tool parameters if it's secret-input - parameter_map = {} - masked_parameter_map = {} - tool_map = {} - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - except Exception as e: - continue - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - parameters = {} - masked_parameter = {} - - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - masked_parameter_map[key] = masked_parameter - parameter_map[key] = parameters - tool_map[key] = tool_runtime - - # encrypt agent tool parameters if it's secret-input - agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - - # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - if key in tool_map: - tool_runtime = tool_map[key] - else: + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app_model.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + + agent_tool_entity = AgentToolEntity(**tool) + # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, agent_callback=None ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) except Exception as e: continue - - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - manager.delete_tool_parameters_cache() - - # override parameters if it equals to masked parameters - if agent_tool_entity.tool_parameters: - if key not in masked_parameter_map: - continue - if agent_tool_entity.tool_parameters == masked_parameter_map[key]: - agent_tool_entity.tool_parameters = parameter_map[key] + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} + + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = 
masked_parameter + parameter_map[key] = parameters + tool_map[key] = tool_runtime + + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + except Exception as e: + continue + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] - # encrypt parameters - if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - # update app model config - new_app_model_config.agent_mode = json.dumps(agent_mode) + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) db.session.add(new_app_model_config) db.session.flush() - app.app_model_config_id = new_app_model_config.id + app_model.app_model_config_id = new_app_model_config.id db.session.commit() app_model_config_was_updated.send( - app, + app_model, app_model_config=new_app_model_config ) diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 4e9d9ed9b45682..256824981e6c72 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -4,7 +4,7 @@ from constants.languages import supported_language from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db @@ -34,13 +34,11 @@ class AppSite(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): + def post(self, app_model): args = parse_app_site_args() - app_id = str(app_id) - app_model = _get_app(app_id) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() @@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id) - + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 7aed7da404aba7..d687b52dc8e6ec 100644 --- 
a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -7,12 +7,13 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required +from models.model import AppMode class DailyConversationStatistic(Resource): @@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model(mode=AppMode.COMPLETION) + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id, 'completion') parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') diff --git a/api/controllers/console/app/workflow.py 
b/api/controllers/console/app/workflow.py new file mode 100644 index 00000000000000..d5967dd5ed31ed --- /dev/null +++ b/api/controllers/console/app/workflow.py @@ -0,0 +1,306 @@ +import json +import logging +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import InternalServerError, NotFound + +import services +from controllers.console import api +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.app.entities.app_invoke_entities import InvokeFrom +from fields.workflow_fields import workflow_fields +from fields.workflow_run_fields import workflow_run_node_execution_fields +from libs.helper import TimestampField, uuid_value +from libs.login import current_user, login_required +from models.model import App, AppMode +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + + +class DraftWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Sync draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') + parser.add_argument('features', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + workflow = workflow_service.sync_draft_workflow( + app_model=app_model, + graph=args.get('graph'), + features=args.get('features'), + account=current_user + ) + + return { + "result": "success", + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + } + + +class AdvancedChatDraftWorkflowRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + def post(self, app_model: App): + """ + Run draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, location='json') + parser.add_argument('query', type=str, required=True, location='json', default='') + parser.add_argument('files', type=list, location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + try: + response = workflow_service.run_advanced_chat_draft_workflow( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except 
services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class DraftWorkflowRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Run draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + + try: + response = workflow_service.run_draft_workflow( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + + return compact_response(response) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class WorkflowTaskStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, task_id: str): + """ + Stop workflow task + """ + workflow_service = WorkflowService() + workflow_service.stop_workflow_task( + task_id=task_id, + user=current_user, + invoke_from=InvokeFrom.DEBUGGER + ) + + return { + "result": "success" + } + + +class DraftWorkflowNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow node + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( + app_model=app_model, + node_id=node_id, + user_inputs=args.get('inputs'), + account=current_user + ) + + return workflow_node_execution + + +class PublishedWorkflowApi(Resource): + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get published workflow + """ + # fetch published workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=app_model) + + # return workflow, if not found, return None + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Publish workflow + """ + workflow_service = WorkflowService() + workflow_service.publish_workflow(app_model=app_model, account=current_user) + + return { + "result": "success" + } + + +class DefaultBlockConfigsApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + """ + Get default block config + """ + # Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_configs() + + +class DefaultBlockConfigApi(Resource): + @setup_required + @login_required + @account_initialization_required + 
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, block_type: str): + """ + Get default block config + """ + parser = reqparse.RequestParser() + parser.add_argument('q', type=str, location='args') + args = parser.parse_args() + + filters = None + if args.get('q'): + try: + filters = json.loads(args.get('q')) + except json.JSONDecodeError: + raise ValueError('Invalid filters') + + # Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_config( + node_type=block_type, + filters=filters + ) + + +class ConvertToWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) + def post(self, app_model: App): + """ + Convert basic mode of chatbot app to workflow mode + Convert expert mode of chatbot app to workflow mode + Convert Completion App to Workflow App + """ + # convert to workflow mode + workflow_service = WorkflowService() + workflow = workflow_service.convert_to_workflow( + app_model=app_model, + account=current_user + ) + + # return workflow + return workflow + + +def compact_response(response: Union[dict, Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') +api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') +api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') +api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop') +api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') +api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish') +api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') +api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs' + '/<string:block_type>') +api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
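The registrations above wire up the draft-workflow lifecycle for the console. As a rough usage sketch (the base URL, app id, and auth header are placeholders, and the inner shape of `graph` is assumed rather than taken from the patch), syncing a draft could look like this:

```python
import requests

# Placeholder values: the console base URL and auth scheme depend on the deployment.
CONSOLE_API = 'http://localhost:5001/console/api'
APP_ID = '00000000-0000-0000-0000-000000000000'
HEADERS = {'Authorization': 'Bearer <console-access-token>'}

# POST /apps/<uuid:app_id>/workflows/draft -- 'graph' and 'features' are both
# required, non-nullable JSON objects per the reqparse rules in DraftWorkflowApi.
resp = requests.post(
    f'{CONSOLE_API}/apps/{APP_ID}/workflows/draft',
    json={'graph': {'nodes': [], 'edges': []}, 'features': {}},
    headers=HEADERS,
)
print(resp.json())  # expected shape: {'result': 'success', 'updated_at': <timestamp>}
```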
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py new file mode 100644 index 00000000000000..6d1709ed8e65d9 --- /dev/null +++ b/api/controllers/console/app/workflow_app_log.py @@ -0,0 +1,41 @@ +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from libs.login import login_required +from models.model import App, AppMode +from services.workflow_app_service import WorkflowAppService + + +class WorkflowAppLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_app_log_pagination_fields) + def get(self, app_model: App): + """ + Get workflow app logs + """ + parser = reqparse.RequestParser() + parser.add_argument('keyword', type=str, location='args') + parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') + parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') + parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + args = parser.parse_args() + + # get paginate workflow app logs + workflow_app_service = WorkflowAppService() + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( + app_model=app_model, + args=args + ) + + return workflow_app_log_pagination + + +api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs') diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py new file mode 100644 index 00000000000000..8a4c0492a1551f --- /dev/null +++ b/api/controllers/console/app/workflow_run.py @@ -0,0 +1,83 @@ +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.workflow_run_fields import ( + workflow_run_detail_fields, + workflow_run_node_execution_list_fields, + workflow_run_pagination_fields, +) +from libs.helper import uuid_value +from libs.login import login_required +from models.model import App, AppMode +from services.workflow_run_service import WorkflowRunService + + +class WorkflowRunListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_pagination_fields) + def get(self, app_model: App): + """ + Get workflow run list + """ + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_paginate_workflow_runs( + app_model=app_model, + args=args + ) + + return result + + +class WorkflowRunDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_detail_fields) + def get(self, app_model: App, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + workflow_run_service = WorkflowRunService() + workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + + return workflow_run + + +class WorkflowRunNodeExecutionListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_list_fields) + def get(self, app_model: App, run_id): + """ + Get workflow run node execution list + """ + run_id = str(run_id) + + workflow_run_service = WorkflowRunService() + node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) + + return { + 'data': node_executions + } + + +api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs') +api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>') +api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
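WorkflowRunListApi uses cursor pagination: an optional `last_id` plus a `limit` clamped to 1-100. A hedged client-side sketch, assuming the marshalled payload exposes `data` and `has_more` (field names inferred from `workflow_run_pagination_fields`, not confirmed here):

```python
import requests

CONSOLE_API = 'http://localhost:5001/console/api'  # placeholder deployment URL
APP_ID = '00000000-0000-0000-0000-000000000000'    # placeholder app id
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # placeholder auth

last_id = None
while True:
    params = {'limit': 20}
    if last_id:
        params['last_id'] = last_id
    page = requests.get(f'{CONSOLE_API}/apps/{APP_ID}/workflow-runs',
                        params=params, headers=HEADERS).json()
    runs = page.get('data', [])
    for run in runs:
        print(run['id'])
    if not runs or not page.get('has_more'):
        break
    last_id = runs[-1]['id']  # pass the oldest seen id back as the cursor
```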
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py new file mode 100644 index 00000000000000..d61ab6d6ae8f28 --- /dev/null +++ b/api/controllers/console/app/wraps.py @@ -0,0 +1,55 @@ +from collections.abc import Callable +from functools import wraps +from typing import Optional, Union + +from controllers.console.app.error import AppNotFoundError +from extensions.ext_database import db +from libs.login import current_user +from models.model import App, AppMode + + +def get_app_model(view: Optional[Callable] = None, *, + mode: Union[AppMode, list[AppMode]] = None): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): + if not kwargs.get('app_id'): + raise ValueError('missing app_id in path parameters') + + app_id = kwargs.get('app_id') + app_id = str(app_id) + + del kwargs['app_id'] + + app_model = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app_model: + raise AppNotFoundError() + + app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.CHANNEL: + raise AppNotFoundError() + + if mode is not None: + if isinstance(mode, list): + modes = mode + else: + modes = [mode] + + if app_mode not in modes: + mode_values = {m.value for m in modes} + raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + + kwargs['app_model'] = app_model + + return view_func(*args, **kwargs) + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index dc546ce0dd28d7..f03663f1a22ea3 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,6 @@ from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import AppModelConfig from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -32,16 +31,12 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=None ) @@ -76,16 +71,12 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.text_to_speech_dict['enabled']: - raise AppUnavailableError() try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) return {'data': response.data.decode('latin1')} diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 6406d5b3b05f6e..b8a5be0df0768a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,8 +21,8 @@ ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from 
core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -90,7 +90,7 @@ def post(self, installed_app, task_id): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 @@ -154,7 +154,7 @@ def post(self, installed_app, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 920d9141ae1189..7d6231270f23d8 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -34,8 +34,7 @@ def get(self): 'is_pinned': installed_app.is_pinned, 'last_used_at': installed_app.last_used_at, 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id, - 'is_agent': installed_app.is_agent + 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id } for installed_app in installed_apps ] diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425fa896..fdb0eae24f00d4 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -24,7 +24,7 @@ NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c4afb0b9236651..9c0fca57f25a04 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with from controllers.console import api +from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource from extensions.ext_database import db -from models.model import AppModelConfig, InstalledApp +from models.model import AppMode, AppModelConfig, InstalledApp from models.tools import ApiToolProvider @@ -45,30 +46,55 @@ class AppParameterApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app - app_model_config = app_model.app_model_config + + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = 
workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" app_model_config: AppModelConfig = installed_app.app.app_model_config + if not app_model_config: + return { + 'tool_icons': {} + } + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} @@ -77,7 +103,7 @@ def get(self, installed_app: InstalledApp): # get all tools tools = agent_config.get('tools', []) url_prefix = (current_app.config.get("CONSOLE_API_URL") - + "/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: @@ -94,12 +120,14 @@ def get(self, installed_app: InstalledApp): ) meta['tool_icons'][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { + meta['tool_icons'][tool_name] = { "background": "#252525", "content": "\ud83d\ude01" } return meta -api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', endpoint='installed_app_parameters') + +api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', + endpoint='installed_app_parameters') api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta') diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index fd90be03b16743..8190f7828dc755 --- a/api/controllers/console/explore/recommended_app.py +++ 
b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,12 @@ from flask_login import current_user -from flask_restful import Resource, fields, marshal_with -from sqlalchemy import and_ +from flask_restful import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api from controllers.console.app.error import AppNotFoundError -from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.login import login_required -from models.model import App, InstalledApp, RecommendedApp -from services.account_service import TenantService +from models.model import App, RecommendedApp +from services.app_service import AppService app_fields = { 'id': fields.String, @@ -27,11 +24,7 @@ 'privacy_policy': fields.String, 'category': fields.String, 'position': fields.Integer, - 'is_listed': fields.Boolean, - 'install_count': fields.Integer, - 'installed': fields.Boolean, - 'editable': fields.Boolean, - 'is_agent': fields.Boolean + 'is_listed': fields.Boolean } recommended_app_list_fields = { @@ -41,11 +34,19 @@ class RecommendedAppListApi(Resource): - @login_required - @account_initialization_required @marshal_with(recommended_app_list_fields) def get(self): - language_prefix = current_user.interface_language if current_user.interface_language else languages[0] + # language args + parser = reqparse.RequestParser() + parser.add_argument('language', type=str, location='args') + args = parser.parse_args() + + if args.get('language') and args.get('language') in languages: + language_prefix = args.get('language') + elif current_user and current_user.interface_language: + language_prefix = current_user.interface_language + else: + language_prefix = languages[0] recommended_apps = db.session.query(RecommendedApp).filter( RecommendedApp.is_listed == True, @@ -53,16 +54,8 @@ def get(self): ).all() categories = set() - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) recommended_apps_result = [] for recommended_app in recommended_apps: - installed = db.session.query(InstalledApp).filter( - and_( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id == current_user.current_tenant_id - ) - ).first() is not None - app = recommended_app.app if not app or not app.is_public: continue @@ -80,11 +73,7 @@ def get(self): 'privacy_policy': site.privacy_policy, 'category': recommended_app.category, 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed, - 'install_count': recommended_app.install_count, - 'installed': installed, - 'editable': current_user.role in ['owner', 'admin'], - "is_agent": app.is_agent + 'is_listed': recommended_app.is_listed } recommended_apps_result.append(recommended_app_result) @@ -94,29 +83,6 @@ def get(self): class RecommendedAppApi(Resource): - model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - } - - app_simple_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 
'mode': fields.String, - 'app_model_config': fields.Nested(model_config_fields), - } - - @login_required - @account_initialization_required - @marshal_with(app_simple_detail_fields) def get(self, app_id): app_id = str(app_id) @@ -130,11 +96,21 @@ def get(self, app_id): raise AppNotFoundError # get app detail - app = db.session.query(App).filter(App.id == app_id).first() - if not app or not app.is_public: + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: raise AppNotFoundError - return app + app_service = AppService() + export_str = app_service.export_app(app_model) + + return { + 'id': app_model.id, + 'name': app_model.name, + 'icon': app_model.icon, + 'icon_background': app_model.icon_background, + 'mode': app_model.mode, + 'export_data': export_str + } api.add_resource(RecommendedAppListApi, '/explore/apps') diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py new file mode 100644 index 00000000000000..7664ba8c165db8 --- /dev/null +++ b/api/controllers/console/ping.py @@ -0,0 +1,17 @@ +from flask_restful import Resource + +from controllers.console import api + + +class PingApi(Resource): + + def get(self): + """ + For connection health check + """ + return { + "result": "pong" + } + + +api.add_resource(PingApi, '/ping') diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index b7cfba9d04c6b1..656a4d4cee6af5 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -16,26 +16,13 @@ ) from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField -} - class AccountInitApi(Resource): diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf57cd4b24c333..f40ccebf25496a 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,33 +1,18 @@ from flask import current_app from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse +from flask_restful import Resource, abort, marshal_with, reqparse import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from libs.helper import TimestampField +from fields.member_fields import account_with_role_list_fields from libs.login import login_required from models.account import Account from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError -account_fields = { - 'id': fields.String, - 
'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, -} - -account_list_fields = { - 'accounts': fields.List(fields.Nested(account_fields)) -} - class MemberListApi(Resource): """List all members of current tenant.""" @@ -35,7 +20,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_list_fields) + @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) return {'result': 'success', 'accounts': members}, 200 diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index a3151fc4a21ea5..76708716c226e4 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with, Resource from controllers.service_api import api +from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, AppMode from models.tools import ApiToolProvider @@ -46,31 +47,55 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - app_model_config = app_model.app_model_config + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { 
+ "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMetaApi(Resource): @validate_app_token def get(self, app_model: App): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + return { + 'tool_icons': {} + } + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index f6cad501f09264..15c0a153b89283 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -20,7 +20,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig, EndUser +from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -33,18 +33,13 @@ class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, - end_user=end_user.get_id() + end_user=end_user ) return response @@ -79,15 +74,16 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('text', type=str, required=True, nullable=False, location='json') + parser.add_argument('voice', type=str, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') args = parser.parse_args() try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=args['text'], - end_user=end_user.get_id(), - voice=app_model.app_model_config.text_to_speech_dict.get('voice'), + end_user=end_user, + voice=args.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c6cfb24378c482..410fb5bffd8e4e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -19,8 +19,8 @@ ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -85,7 +85,7 @@ def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'completion': raise AppUnavailableError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + 
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 @@ -147,7 +147,7 @@ def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 25492b11432a6e..07ce098298a364 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -4,9 +4,10 @@ from flask_restful import fields, marshal_with from controllers.web import api +from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, AppMode from models.tools import ApiToolProvider @@ -44,30 +45,52 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - app_model_config = app_model.app_model_config + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMeta(WebApiResource): def get(self, app_model: App, end_user): """Get app meta""" 
app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + raise AppUnavailableError() + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 4e677ae288dd4a..e0074c452fb851 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -19,7 +19,7 @@ from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig +from models.model import App from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -31,16 +31,11 @@ class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=end_user ) @@ -74,17 +69,12 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.text_to_speech_dict['enabled']: - raise AppUnavailableError() - try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], end_user=end_user.external_user_id, - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 61d4f8c36232ba..ed1378e7e3d85f 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,8 +20,8 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -84,7 +84,7 @@ def post(self, app_model, end_user, task_id): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 @@ -144,7 +144,7 @@ def post(self, app_model, end_user, task_id): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb2a27..1acb92dbf1eda5 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -21,7 +21,7 @@ ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from 
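The web controller hunks above, like the service-api hunks earlier, apply the same mechanical rename. Collected in one place, with placeholder arguments for illustration:

```python
# Import-path migration applied across these controller hunks:
#   core.application_queue_manager.ApplicationQueueManager
#       -> core.app.apps.base_app_queue_manager.AppQueueManager
#   core.entities.application_entities.InvokeFrom
#       -> core.app.entities.app_invoke_entities.InvokeFrom
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom

task_id = "task-123"       # placeholder values for illustration
end_user_id = "user-456"
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user_id)
```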
core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index d8e2d597071889..bf3536d2766ed8 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -83,7 +83,3 @@ def __init__(self, tenant, app, site, end_user, can_replace_logo): 'remove_webapp_brand': remove_webapp_brand, 'replace_webapp_logo': replace_webapp_logo, } - - if app.enable_site and site.prompt_public: - app_model_config = app.app_model_config - self.model_config = app_model_config diff --git a/api/core/app_runner/__init__.py b/api/core/agent/__init__.py similarity index 100% rename from api/core/app_runner/__init__.py rename to api/core/agent/__init__.py diff --git a/api/core/features/assistant_base_runner.py b/api/core/agent/base_agent_runner.py similarity index 79% rename from api/core/features/assistant_base_runner.py rename to api/core/agent/base_agent_runner.py index 1d9541070f881f..14602a72656b39 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,21 +2,19 @@ import logging import uuid from datetime import datetime -from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - AgentEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, +from core.agent.entities import AgentEntity, AgentToolEntity +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, InvokeFrom, - ModelConfigEntity, + ModelConfigWithCredentialsEntity, ) +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -40,7 +38,6 @@ ) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool -from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager from extensions.ext_database import db from models.model import Message, MessageAgentThought, MessageFile @@ -48,13 +45,13 @@ logger = logging.getLogger(__name__) -class BaseAssistantApplicationRunner(AppRunner): +class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, - application_generate_entity: ApplicationGenerateEntity, - app_orchestration_config: AppOrchestrationConfigEntity, - model_config: ModelConfigEntity, + application_generate_entity: AgentChatAppGenerateEntity, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, config: AgentEntity, - 
queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, @@ -66,7 +63,7 @@ def __init__(self, tenant_id: str, """ Agent runner :param tenant_id: tenant id - :param app_orchestration_config: app orchestration config + :param app_config: app generate entity :param model_config: model config :param config: dataset config :param queue_manager: queue manager @@ -78,7 +75,7 @@ def __init__(self, tenant_id: str, """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity - self.app_orchestration_config = app_orchestration_config + self.app_config = app_config self.model_config = model_config self.config = config self.queue_manager = queue_manager @@ -97,16 +94,16 @@ def __init__(self, tenant_id: str, # init dataset tools hit_callback = DatasetIndexToolCallbackHandler( queue_manager=queue_manager, - app_id=self.application_generate_entity.app_id, + app_id=self.app_config.app_id, message_id=message.id, user_id=user_id, invoke_from=self.application_generate_entity.invoke_from, ) self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( tenant_id=tenant_id, - dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], - retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, - return_resource=app_orchestration_config.show_retrieve_source, + dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], + retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, + return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback ) @@ -124,14 +121,15 @@ def __init__(self, tenant_id: str, else: self.stream_tool_call = False - def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: + def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ + -> AgentChatAppGenerateEntity: """ - Repack app orchestration config + Repack app generate entity """ - if app_orchestration_config.prompt_template.simple_prompt_template is None: - app_orchestration_config.prompt_template.simple_prompt_template = '' + if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: + app_generate_entity.app_config.prompt_template.simple_prompt_template = '' - return app_orchestration_config + return app_generate_entity def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: """ @@ -351,7 +349,7 @@ def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[ )) db.session.close() - + return result def create_agent_thought(self, message_id: str, message: str, @@ -463,73 +461,6 @@ def save_agent_thought(self, db.session.commit() db.session.close() - def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: - """ - Transform tool message into agent thought - """ - result = [] - - for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.IMAGE: - # try to download image - try: - file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id, - 
conversation_id=self.message.conversation_id, - file_url=message.message) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - except Exception as e: - logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') - # if message is str, encode it to bytes - if isinstance(message.message, str): - message.message = message.message.encode('utf-8') - file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_binary=message.message, - mimetype=mimetype) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' - - # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(message) - - return result - def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ convert tool variables to db variables diff --git a/api/core/features/assistant_cot_runner.py b/api/core/agent/cot_agent_runner.py similarity index 89% rename from api/core/features/assistant_cot_runner.py rename to api/core/agent/cot_agent_runner.py index 3762ddcf62e7c5..0c5399f5416d65 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -3,9 +3,10 @@ from collections.abc import Generator from typing import Literal, Union -from core.application_queue_manager import PublishFrom -from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -24,10 +25,11 @@ ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message -class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): +class CotAgentRunner(BaseAgentRunner): _is_first_iteration = True _ignore_observation_providers = ['wenxin'] @@ -39,30 +41,33 @@ def run(self, conversation: Conversation, """ Run Cot agent application """ - 
app_orchestration_config = self.app_orchestration_config - self._repack_app_orchestration_config(app_orchestration_config) + app_generate_entity = self.application_generate_entity + self._repack_app_generate_entity(app_generate_entity) agent_scratchpad: list[AgentScratchpadUnit] = [] self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) - if 'Observation' not in app_orchestration_config.model_config.stop: - if app_orchestration_config.model_config.provider not in self._ignore_observation_providers: - app_orchestration_config.model_config.stop.append('Observation') + # check model mode + if 'Observation' not in app_generate_entity.model_config.stop: + if app_generate_entity.model_config.provider not in self._ignore_observation_providers: + app_generate_entity.model_config.stop.append('Observation') + + app_config = self.app_config # override inputs inputs = inputs or {} - instruction = self.app_orchestration_config.prompt_template.simple_prompt_template + instruction = app_config.prompt_template.simple_prompt_template instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 - max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 prompt_messages = self.history_prompt_messages # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -118,15 +123,17 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ) if iteration_step > 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt messages prompt_messages = self._organize_cot_prompt_messages( - mode=app_orchestration_config.model_config.mode, + mode=app_generate_entity.model_config.mode, prompt_messages=prompt_messages, tools=prompt_messages_tools, agent_scratchpad=agent_scratchpad, - agent_prompt_message=app_orchestration_config.agent.prompt, + agent_prompt_message=app_config.agent.prompt, instruction=instruction, input=query ) @@ -136,9 +143,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, tools=[], - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=True, user=self.user_id, callbacks=[], @@ -160,7 +167,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) for chunk in react_chunks: if isinstance(chunk, dict): @@ -222,7 +231,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: 
LLMUsage): llm_usage=usage_dict['usage']) if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) if not scratchpad.action: # failed to extract action, return final answer directly @@ -252,7 +263,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): observation=answer, answer=answer, messages_ids=[]) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) else: # invoke tool error_response = None @@ -262,13 +275,18 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_call_args = json.loads(tool_call_args) except json.JSONDecodeError: pass - + tool_response = tool_instance.invoke( user_id=self.user_id, tool_parameters=tool_call_args ) # transform tool response to llm friendly response - tool_response = self.transform_tool_invoke_messages(tool_response) + tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_response, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_response) # create message file @@ -279,7 +297,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) message_file_ids = [message_file.id for message_file, _ in message_files] except ToolProviderCredentialValidationError as e: @@ -315,7 +335,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=scratchpad.agent_response, messages_ids=message_file_ids, ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool message for prompt_tool in prompt_messages_tools: @@ -349,7 +371,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( @@ -357,7 +379,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ -> Generator[Union[str, dict], None, None]: @@ -550,7 +572,7 @@ def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpad """ convert agent scratchpad list to str """ - next_iteration = 
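These runner hunks replace the per-event `publish_agent_thought` / `publish_message_file` / `publish_message_end` methods with a single `publish()` that takes a typed event object. A self-contained sketch of the pattern; the class bodies are simplified stand-ins, not the real implementations:

```python
from enum import Enum
from pydantic import BaseModel

class PublishFrom(Enum):
    APPLICATION_MANAGER = 1

class QueueAgentThoughtEvent(BaseModel):
    agent_thought_id: str

class QueueMessageFileEvent(BaseModel):
    message_file_id: str

class SimpleQueueManager:
    """Stand-in for AppQueueManager: one publish() for every event type."""
    def publish(self, event: BaseModel, pub_from: PublishFrom) -> None:
        # a real implementation would serialize the event onto the app queue
        print(f"[{pub_from.name}] {type(event).__name__}: {event.dict()}")

qm = SimpleQueueManager()
qm.publish(QueueAgentThoughtEvent(agent_thought_id="thought-1"),
           PublishFrom.APPLICATION_MANAGER)
qm.publish(QueueMessageFileEvent(message_file_id="file-1"),
           PublishFrom.APPLICATION_MANAGER)
```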
self.app_orchestration_config.agent.prompt.next_iteration + next_iteration = self.app_config.agent.prompt.next_iteration result = '' for scratchpad in agent_scratchpad: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py new file mode 100644 index 00000000000000..e7016d6030cc20 --- /dev/null +++ b/api/core/agent/entities.py @@ -0,0 +1,61 @@ +from enum import Enum +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + + +class AgentToolEntity(BaseModel): + """ + Agent Tool Entity. + """ + provider_type: Literal["builtin", "api"] + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] = {} + + +class AgentPromptEntity(BaseModel): + """ + Agent Prompt Entity. + """ + first_prompt: str + next_iteration: str + + +class AgentScratchpadUnit(BaseModel): + """ + Agent First Prompt Entity. + """ + + class Action(BaseModel): + """ + Action Entity. + """ + action_name: str + action_input: Union[dict, str] + + agent_response: Optional[str] = None + thought: Optional[str] = None + action_str: Optional[str] = None + observation: Optional[str] = None + action: Optional[Action] = None + + +class AgentEntity(BaseModel): + """ + Agent Entity. + """ + + class Strategy(Enum): + """ + Agent Strategy. + """ + CHAIN_OF_THOUGHT = 'chain-of-thought' + FUNCTION_CALLING = 'function-calling' + + provider: str + model: str + strategy: Strategy + prompt: Optional[AgentPromptEntity] = None + tools: list[AgentToolEntity] = None + max_iteration: int = 5 diff --git a/api/core/features/assistant_fc_runner.py b/api/core/agent/fc_agent_runner.py similarity index 88% rename from api/core/features/assistant_fc_runner.py rename to api/core/agent/fc_agent_runner.py index 391e040c53d32b..185d7684c82ad4 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -3,8 +3,9 @@ from collections.abc import Generator from typing import Any, Union -from core.application_queue_manager import PublishFrom -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.agent.base_agent_runner import BaseAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -22,11 +23,12 @@ ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): +class FunctionCallAgentRunner(BaseAgentRunner): def run(self, conversation: Conversation, message: Message, query: str, @@ -34,9 +36,11 @@ def run(self, conversation: Conversation, """ Run FunctionCall agent application """ - app_orchestration_config = self.app_orchestration_config + app_generate_entity = self.application_generate_entity - prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' + app_config = self.app_config + + prompt_template = app_config.prompt_template.simple_prompt_template or '' prompt_messages = self.history_prompt_messages prompt_messages = self.organize_prompt_messages( prompt_template=prompt_template, @@ -47,7 +51,7 @@ def run(self, 
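The new `core/agent/entities.py` gives the agent configuration a typed home. An illustrative construction using the entities defined above; every concrete value is invented for the example:

```python
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity

agent = AgentEntity(
    provider="openai",
    model="gpt-3.5-turbo",
    strategy=AgentEntity.Strategy.FUNCTION_CALLING,
    prompt=AgentPromptEntity(
        first_prompt="Answer using the tools available.",
        next_iteration="Previous reasoning: {{agent_scratchpad}}",
    ),
    tools=[AgentToolEntity(
        provider_type="builtin",
        provider_id="time",          # illustrative tool identifiers
        tool_name="current_time",
        tool_parameters={},
    )],
    max_iteration=5,
)
```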
conversation: Conversation, # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -67,7 +71,7 @@ def run(self, conversation: Conversation, tool_instances[dataset_tool.identity.name] = dataset_tool iteration_step = 1 - max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # continue to run until there is not any tool call function_call_state = True @@ -110,9 +114,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, tools=prompt_messages_tools, - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=self.stream_tool_call, user=self.user_id, callbacks=[], @@ -133,7 +137,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): @@ -193,7 +199,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not result.message.content: result.message.content = '' - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) yield LLMResultChunk( model=model_instance.model, @@ -231,8 +239,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): messages_ids=[], llm_usage=current_llm_usage ) - - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) final_answer += response + '\n' @@ -262,7 +271,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_parameters=tool_call_args, ) # transform tool invoke message to get LLM friendly message - tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) + tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_invoke_message, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_invoke_message) # create message file @@ -273,7 +287,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) # publish message file - self.queue_manager.publish_message_file(message_file, 
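Both runners cap the configured iteration budget the same way, which a quick worked example makes concrete:

```python
# min(max_iteration, 5) + 1: the configured budget is clamped at 5,
# then one extra step is added as the closing turn.
for configured in (1, 3, 10):
    print(configured, "->", min(configured, 5) + 1)   # 1 -> 2, 3 -> 4, 10 -> 6
```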
PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) # add message file ids message_file_ids.append(message_file.id) @@ -329,7 +345,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=None, messages_ids=message_file_ids ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -339,15 +357,15 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=final_answer, + content=final_answer ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ diff --git a/api/core/features/__init__.py b/api/core/app/__init__.py similarity index 100% rename from api/core/features/__init__.py rename to api/core/app/__init__.py diff --git a/api/core/features/dataset_retrieval/__init__.py b/api/core/app/app_config/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/__init__.py rename to api/core/app/app_config/__init__.py diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py new file mode 100644 index 00000000000000..e09aa0376685a1 --- /dev/null +++ b/api/core/app/app_config/base_app_config_manager.py @@ -0,0 +1,74 @@ +from typing import Optional, Union + +from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppModelConfig + + +class BaseAppConfigManager: + + @classmethod + def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: Union[AppModelConfig, dict], + config_dict: Optional[dict] = None) -> dict: + """ + Convert app model config to config dict + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + + return config_dict + + @classmethod + def convert_features(cls, config_dict: dict) -> 
AppAdditionalFeatures: + """ + Convert app config to app model config + + :param config_dict: app config + """ + config_dict = config_dict.copy() + + additional_features = AppAdditionalFeatures() + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( + config=config_dict + ) + + additional_features.file_upload = FileUploadConfigManager.convert( + config=config_dict + ) + + additional_features.opening_statement, additional_features.suggested_questions = \ + OpeningStatementConfigManager.convert( + config=config_dict + ) + + additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( + config=config_dict + ) + + additional_features.more_like_this = MoreLikeThisConfigManager.convert( + config=config_dict + ) + + additional_features.speech_to_text = SpeechToTextConfigManager.convert( + config=config_dict + ) + + additional_features.text_to_speech = TextToSpeechConfigManager.convert( + config=config_dict + ) + + return additional_features diff --git a/api/core/features/dataset_retrieval/agent/__init__.py b/api/core/app/app_config/common/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/__init__.py rename to api/core/app/app_config/common/__init__.py diff --git a/api/core/features/dataset_retrieval/agent/output_parser/__init__.py b/api/core/app/app_config/common/sensitive_word_avoidance/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/__init__.py rename to api/core/app/app_config/common/sensitive_word_avoidance/__init__.py diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py new file mode 100644 index 00000000000000..3dccfa3cbed2da --- /dev/null +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -0,0 +1,50 @@ +from typing import Optional + +from core.app.app_config.entities import SensitiveWordAvoidanceEntity +from core.moderation.factory import ModerationFactory + + +class SensitiveWordAvoidanceConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: + sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + if not sensitive_word_avoidance_dict: + return None + + if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: + return SensitiveWordAvoidanceEntity( + type=sensitive_word_avoidance_dict.get('type'), + config=sensitive_word_avoidance_dict.get('config'), + ) + else: + return None + + @classmethod + def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ + -> tuple[dict, list[str]]: + if not config.get("sensitive_word_avoidance"): + config["sensitive_word_avoidance"] = { + "enabled": False + } + + if not isinstance(config["sensitive_word_avoidance"], dict): + raise ValueError("sensitive_word_avoidance must be of dict type") + + if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: + config["sensitive_word_avoidance"]["enabled"] = False + + if config["sensitive_word_avoidance"]["enabled"]: + if not config["sensitive_word_avoidance"].get("type"): + raise ValueError("sensitive_word_avoidance.type is required") + + if not only_structure_validate: + typ = config["sensitive_word_avoidance"]["type"] + config = config["sensitive_word_avoidance"]["config"] + + ModerationFactory.validate_config( + name=typ, + 
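A hedged usage sketch for the new sensitive-word-avoidance manager; `"keywords"` is used as an illustrative moderation type, since the exact type names come from the moderation factory rather than this diff:

```python
from core.app.app_config.common.sensitive_word_avoidance.manager import (
    SensitiveWordAvoidanceConfigManager,
)

entity = SensitiveWordAvoidanceConfigManager.convert({
    "sensitive_word_avoidance": {
        "enabled": True,
        "type": "keywords",                  # illustrative moderation type
        "config": {"keywords": "foo,bar"},   # illustrative payload
    }
})
print(entity)                                           # populated entity
print(SensitiveWordAvoidanceConfigManager.convert({}))  # None when block is absent
```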
tenant_id=tenant_id, + config=config + ) + + return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/app_config/easy_ui_based_app/__init__.py b/api/core/app/app_config/easy_ui_based_app/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/agent/__init__.py b/api/core/app/app_config/easy_ui_based_app/agent/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py new file mode 100644 index 00000000000000..b50b7f678c492d --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -0,0 +1,79 @@ +from typing import Optional + +from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity +from core.tools.prompt.template import REACT_PROMPT_TEMPLATES + + +class AgentConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[AgentEntity]: + """ + Convert model config to model config + + :param config: model config args + """ + if 'agent_mode' in config and config['agent_mode'] \ + and 'enabled' in config['agent_mode'] \ + and config['agent_mode']['enabled']: + + agent_dict = config.get('agent_mode', {}) + agent_strategy = agent_dict.get('strategy', 'cot') + + if agent_strategy == 'function_call': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + elif agent_strategy == 'cot' or agent_strategy == 'react': + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + else: + # old configs, try to detect default strategy + if config['model']['provider'] == 'openai': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + + agent_tools = [] + for tool in agent_dict.get('tools', []): + keys = tool.keys() + if len(keys) >= 4: + if "enabled" not in tool or not tool["enabled"]: + continue + + agent_tool_properties = { + 'provider_type': tool['provider_type'], + 'provider_id': tool['provider_id'], + 'tool_name': tool['tool_name'], + 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} + } + + agent_tools.append(AgentToolEntity(**agent_tool_properties)) + + if 'strategy' in config['agent_mode'] and \ + config['agent_mode']['strategy'] not in ['react_router', 'router']: + agent_prompt = agent_dict.get('prompt', None) or {} + # check model mode + model_mode = config.get('model', {}).get('mode', 'completion') + if model_mode == 'completion': + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['completion'][ + 'agent_scratchpad']), + ) + else: + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + ) + + return AgentEntity( + provider=config['model']['provider'], + model=config['model']['name'], + strategy=strategy, + prompt=agent_prompt_entity, + tools=agent_tools, + max_iteration=agent_dict.get('max_iteration', 5) + ) + + return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/__init__.py b/api/core/app/app_config/easy_ui_based_app/dataset/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git 
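A usage sketch for `AgentConfigManager.convert`, exercising the `function_call` branch above; the model and tool identifiers are illustrative:

```python
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager

agent = AgentConfigManager.convert({
    "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
    "agent_mode": {
        "enabled": True,
        "strategy": "function_call",
        "max_iteration": 3,
        "tools": [{
            "enabled": True,
            "provider_type": "builtin",
            "provider_id": "time",        # illustrative tool
            "tool_name": "current_time",
            "tool_parameters": {},
        }],
    },
})
# agent.strategy is Strategy.FUNCTION_CALLING; "react"/"cot", or a legacy
# config on a non-OpenAI provider, resolves to CHAIN_OF_THOUGHT instead.
```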
a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py new file mode 100644 index 00000000000000..c10aa98dbacf9e --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -0,0 +1,224 @@ +from typing import Optional + +from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode +from services.dataset_service import DatasetService + + +class DatasetConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[DatasetEntity]: + """ + Convert model config to model config + + :param config: model config args + """ + dataset_ids = [] + if 'datasets' in config.get('dataset_configs', {}): + datasets = config.get('dataset_configs', {}).get('datasets', { + 'strategy': 'router', + 'datasets': [] + }) + + for dataset in datasets.get('datasets', []): + keys = list(dataset.keys()) + if len(keys) == 0 or keys[0] != 'dataset': + continue + + dataset = dataset['dataset'] + + if 'enabled' not in dataset or not dataset['enabled']: + continue + + dataset_id = dataset.get('id', None) + if dataset_id: + dataset_ids.append(dataset_id) + + if 'agent_mode' in config and config['agent_mode'] \ + and 'enabled' in config['agent_mode'] \ + and config['agent_mode']['enabled']: + + agent_dict = config.get('agent_mode', {}) + + for tool in agent_dict.get('tools', []): + keys = tool.keys() + if len(keys) == 1: + # old standard + key = list(tool.keys())[0] + + if key != 'dataset': + continue + + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + continue + + dataset_id = tool_item['id'] + dataset_ids.append(dataset_id) + + if len(dataset_ids) == 0: + return None + + # dataset configs + dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'}) + query_variable = config.get('dataset_query_variable') + + if dataset_configs['retrieval_model'] == 'single': + return DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ) + ) + ) + else: + return DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ), + top_k=dataset_configs.get('top_k'), + score_threshold=dataset_configs.get('score_threshold'), + reranking_model=dataset_configs.get('reranking_model') + ) + ) + + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for dataset feature + + :param tenant_id: tenant ID + :param app_mode: app mode + :param config: app model config args + """ + # Extract dataset config for legacy compatibility + config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config) + + # dataset_configs + if not config.get("dataset_configs"): + config["dataset_configs"] = {'retrieval_model': 'single'} + + if not config["dataset_configs"].get("datasets"): + config["dataset_configs"]["datasets"] = { + "strategy": "router", + "datasets": [] + } + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + if 
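A usage sketch for `DatasetConfigManager.convert` in the single-retrieval case; the dataset id is a placeholder UUID:

```python
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager

entity = DatasetConfigManager.convert({
    "dataset_configs": {
        "retrieval_model": "single",
        "datasets": {"datasets": [
            {"dataset": {"enabled": True,
                         "id": "11111111-1111-1111-1111-111111111111"}},
        ]},
    },
    "dataset_query_variable": "query",
})
# entity.dataset_ids == ["11111111-1111-1111-1111-111111111111"];
# convert() returns None when no enabled dataset is found.
```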
config["dataset_configs"]['retrieval_model'] == 'multiple': + if not config["dataset_configs"]['reranking_model']: + raise ValueError("reranking_model has not been set") + if not isinstance(config["dataset_configs"]['reranking_model'], dict): + raise ValueError("reranking_model must be of object type") + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + need_manual_query_datasets = (config.get("dataset_configs") + and config["dataset_configs"].get("datasets", {}).get("datasets")) + + if need_manual_query_datasets and app_mode == AppMode.COMPLETION: + # Only check when mode is completion + dataset_query_variable = config.get("dataset_query_variable") + + if not dataset_query_variable: + raise ValueError("Dataset query variable is required when dataset is exist") + + return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] + + @classmethod + def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: + """ + Extract dataset config for legacy compatibility + + :param tenant_id: tenant ID + :param app_mode: app mode + :param config: app model config args + """ + # Extract dataset config for legacy compatibility + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + # enabled + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + # tools + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + # strategy + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + has_datasets = False + if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key == "dataset": + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not cls.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + + has_datasets = True + + need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"] + + if need_manual_query_datasets and app_mode == AppMode.COMPLETION: + # Only check when mode is completion + dataset_query_variable = config.get("dataset_query_variable") + + if not dataset_query_variable: + raise ValueError("Dataset query variable is required when dataset is exist") + + return config + + @classmethod + def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool: + # verify if the dataset ID exists + dataset = DatasetService.get_dataset(dataset_id) + + if not dataset: + return 
False + + if dataset.tenant_id != tenant_id: + return False + + return True diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py b/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py new file mode 100644 index 00000000000000..5c9b2cfec7babf --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -0,0 +1,103 @@ +from typing import cast + +from core.app.app_config.entities import EasyUIBasedAppConfig +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.provider_manager import ProviderManager + + +class ModelConfigConverter: + @classmethod + def convert(cls, app_config: EasyUIBasedAppConfig, + skip_check: bool = False) \ + -> ModelConfigWithCredentialsEntity: + """ + Convert app model config dict to entity. + :param app_config: app config + :param skip_check: skip check + :raises ProviderTokenNotInitError: provider token not init error + :return: app orchestration config entity + """ + model_config = app_config.model + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=app_config.tenant_id, + provider=model_config.provider, + model_type=ModelType.LLM + ) + + provider_name = provider_model_bundle.configuration.provider.provider + model_name = model_config.model + + model_type_instance = provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + # check model credentials + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, + model=model_config.model + ) + + if model_credentials is None: + if not skip_check: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + else: + model_credentials = {} + + if not skip_check: + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, + model_type=ModelType.LLM + ) + + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = model_config.parameters + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = model_config.mode + if not model_mode: + mode_enum = model_type_instance.get_model_mode( + model=model_config.model, + credentials=model_credentials + ) + + model_mode = mode_enum.value + + model_schema = 
model_type_instance.get_model_schema( + model_config.model, + model_credentials + ) + + if not skip_check and not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return ModelConfigWithCredentialsEntity( + provider=model_config.provider, + model=model_config.model, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py new file mode 100644 index 00000000000000..730a9527cf7315 --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -0,0 +1,112 @@ +from core.app.app_config.entities import ModelConfigEntity +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager + + +class ModelConfigManager: + @classmethod + def convert(cls, config: dict) -> ModelConfigEntity: + """ + Convert model config to model config + + :param config: model config args + """ + # model config + model_config = config.get('model') + + if not model_config: + raise ValueError("model is required") + + completion_params = model_config.get('completion_params') + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = model_config.get('mode') + + return ModelConfigEntity( + provider=config['model']['provider'], + model=config['model']['name'], + mode=model_mode, + parameters=completion_params, + stop=stop, + ) + + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for model config + + :param tenant_id: tenant id + :param config: app model config args + """ + if 'model' not in config: + raise ValueError("model is required") + + if not isinstance(config["model"], dict): + raise ValueError("model must be of object type") + + # model.provider + provider_entities = model_provider_factory.get_providers() + model_provider_names = [provider.provider for provider in provider_entities] + if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") + + # model.name + if 'name' not in config["model"]: + raise ValueError("model.name is required") + + provider_manager = ProviderManager() + models = provider_manager.get_configurations(tenant_id).get_models( + provider=config["model"]["provider"], + model_type=ModelType.LLM + ) + + if not models: + raise ValueError("model.name must be in the specified model list") + + model_ids = [m.model for m in models] + if config["model"]["name"] not in model_ids: + raise ValueError("model.name must be in the specified model list") + + model_mode = None + for model in models: + if model.model == config["model"]["name"]: + model_mode = model.model_properties.get(ModelPropertyKey.MODE) + break + + # model.mode + if model_mode: + config['model']["mode"] = model_mode + else: + config['model']["mode"] = "completion" + + # model.completion_params + if 'completion_params' not in config["model"]: + raise ValueError("model.completion_params is required") + + config["model"]["completion_params"] = cls.validate_model_completion_params( + 
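A usage sketch for `ModelConfigManager.convert`; note that it mutates its input, popping `stop` out of `completion_params` into a dedicated field:

```python
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager

model_entity = ModelConfigManager.convert({
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo",
        "mode": "chat",
        "completion_params": {"temperature": 0.7, "stop": ["Observation"]},
    }
})
# model_entity.stop == ["Observation"]
# model_entity.parameters == {"temperature": 0.7}  (stop was deleted in place)
```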
config["model"]["completion_params"] + ) + + return config, ["model"] + + @classmethod + def validate_model_completion_params(cls, cp: dict) -> dict: + # model.completion_params + if not isinstance(cp, dict): + raise ValueError("model.completion_params must be of object type") + + # stop + if 'stop' not in cp: + cp["stop"] = [] + elif not isinstance(cp["stop"], list): + raise ValueError("stop in model.completion_params must be of list type") + + if len(cp["stop"]) > 4: + raise ValueError("stop sequences must be less than 4") + + return cp diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py new file mode 100644 index 00000000000000..1f410758aa41da --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -0,0 +1,140 @@ +from core.app.app_config.entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.prompt.simple_prompt_transform import ModelMode +from models.model import AppMode + + +class PromptTemplateConfigManager: + @classmethod + def convert(cls, config: dict) -> PromptTemplateEntity: + if not config.get("prompt_type"): + raise ValueError("prompt_type is required") + + prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + simple_prompt_template = config.get("pre_prompt", "") + return PromptTemplateEntity( + prompt_type=prompt_type, + simple_prompt_template=simple_prompt_template + ) + else: + advanced_chat_prompt_template = None + chat_prompt_config = config.get("chat_prompt_config", {}) + if chat_prompt_config: + chat_prompt_messages = [] + for message in chat_prompt_config.get("prompt", []): + chat_prompt_messages.append({ + "text": message["text"], + "role": PromptMessageRole.value_of(message["role"]) + }) + + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( + messages=chat_prompt_messages + ) + + advanced_completion_prompt_template = None + completion_prompt_config = config.get("completion_prompt_config", {}) + if completion_prompt_config: + completion_prompt_template_params = { + 'prompt': completion_prompt_config['prompt']['text'], + } + + if 'conversation_histories_role' in completion_prompt_config: + completion_prompt_template_params['role_prefix'] = { + 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], + 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + } + + advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( + **completion_prompt_template_params + ) + + return PromptTemplateEntity( + prompt_type=prompt_type, + advanced_chat_prompt_template=advanced_chat_prompt_template, + advanced_completion_prompt_template=advanced_completion_prompt_template + ) + + @classmethod + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: + """ + Validate pre_prompt and set defaults for prompt feature + depending on the config['model'] + + :param app_mode: app mode + :param config: app model config args + """ + if not config.get("prompt_type"): + 
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + + prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + if config['prompt_type'] not in prompt_type_vals: + raise ValueError(f"prompt_type must be in {prompt_type_vals}") + + # chat_prompt_config + if not config.get("chat_prompt_config"): + config["chat_prompt_config"] = {} + + if not isinstance(config["chat_prompt_config"], dict): + raise ValueError("chat_prompt_config must be of object type") + + # completion_prompt_config + if not config.get("completion_prompt_config"): + config["completion_prompt_config"] = {} + + if not isinstance(config["completion_prompt_config"], dict): + raise ValueError("completion_prompt_config must be of object type") + + if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config['chat_prompt_config'] and not config['completion_prompt_config']: + raise ValueError("chat_prompt_config or completion_prompt_config is required " + "when prompt_type is advanced") + + model_mode_vals = [mode.value for mode in ModelMode] + if config['model']["mode"] not in model_mode_vals: + raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") + + if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: + user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] + assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + + if not user_prefix: + config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + + if not assistant_prefix: + config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + + if config['model']["mode"] == ModelMode.CHAT.value: + prompt_list = config['chat_prompt_config']['prompt'] + + if len(prompt_list) > 10: + raise ValueError("prompt messages must be less than 10") + else: + # pre_prompt, for simple mode + if not config.get("pre_prompt"): + config["pre_prompt"] = "" + + if not isinstance(config["pre_prompt"], str): + raise ValueError("pre_prompt must be of string type") + + return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] + + @classmethod + def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: + """ + Validate post_prompt and set defaults for prompt feature + + :param config: app model config args + """ + # post_prompt + if not config.get("post_prompt"): + config["post_prompt"] = "" + + if not isinstance(config["post_prompt"], str): + raise ValueError("post_prompt must be of string type") + + return config diff --git a/api/core/app/app_config/easy_ui_based_app/variables/__init__.py b/api/core/app/app_config/easy_ui_based_app/variables/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py new file mode 100644 index 00000000000000..1237da502b4258 --- /dev/null +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -0,0 +1,183 @@ +import re + +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.external_data_tool.factory import ExternalDataToolFactory + + +class BasicVariablesConfigManager: + @classmethod + def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + """ + Convert model config to model config + 
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/__init__.py b/api/core/app/app_config/easy_ui_based_app/variables/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
new file mode 100644
index 00000000000000..1237da502b4258
--- /dev/null
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -0,0 +1,183 @@
+import re
+
+from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+
+class BasicVariablesConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
+        """
+        Convert model config to variable entities and external data variables
+
+        :param config: model config args
+        """
+        external_data_variables = []
+        variables = []
+
+        # old external_data_tools
+        external_data_tools = config.get('external_data_tools', [])
+        for external_data_tool in external_data_tools:
+            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+                continue
+
+            external_data_variables.append(
+                ExternalDataVariableEntity(
+                    variable=external_data_tool['variable'],
+                    type=external_data_tool['type'],
+                    config=external_data_tool['config']
+                )
+            )
+
+        # variables and external_data_tools
+        for variable in config.get('user_input_form', []):
+            typ = list(variable.keys())[0]
+            if typ == 'external_data_tool':
+                val = variable[typ]
+                external_data_variables.append(
+                    ExternalDataVariableEntity(
+                        variable=val['variable'],
+                        type=val['type'],
+                        config=val['config']
+                    )
+                )
+            elif typ in [
+                VariableEntity.Type.TEXT_INPUT.value,
+                VariableEntity.Type.PARAGRAPH.value,
+                VariableEntity.Type.NUMBER.value,
+            ]:
+                variables.append(
+                    VariableEntity(
+                        type=VariableEntity.Type.value_of(typ),
+                        variable=variable[typ].get('variable'),
+                        description=variable[typ].get('description'),
+                        label=variable[typ].get('label'),
+                        required=variable[typ].get('required', False),
+                        max_length=variable[typ].get('max_length'),
+                        default=variable[typ].get('default'),
+                    )
+                )
+            elif typ == VariableEntity.Type.SELECT.value:
+                variables.append(
+                    VariableEntity(
+                        type=VariableEntity.Type.SELECT,
+                        variable=variable[typ].get('variable'),
+                        description=variable[typ].get('description'),
+                        label=variable[typ].get('label'),
+                        required=variable[typ].get('required', False),
+                        options=variable[typ].get('options'),
+                        default=variable[typ].get('default'),
+                    )
+                )
+
+        return variables, external_data_variables
+
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param tenant_id: workspace id
+        :param config: app model config args
+        """
+        related_config_keys = []
+        config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
+        related_config_keys.extend(current_related_config_keys)
+
+        return config, related_config_keys
+
+    @classmethod
+    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param config: app model config args
+        """
+        if not config.get("user_input_form"):
+            config["user_input_form"] = []
+
+        if not isinstance(config["user_input_form"], list):
+            raise ValueError("user_input_form must be a list of objects")
+
+        variables = []
+        for item in config["user_input_form"]:
+            key = list(item.keys())[0]
+            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph', "
+                                 "'select', 'number' or 'external_data_tool'")
+
+            form_item = item[key]
+            if 'label' not in form_item:
+                raise ValueError("label is required in user_input_form")
+
+            if not isinstance(form_item["label"], str):
+                raise ValueError("label in user_input_form must be of string type")
+
+            if 'variable' not in form_item:
+                raise ValueError("variable is required in user_input_form")
+
+            if not isinstance(form_item["variable"], str):
+                raise ValueError("variable in user_input_form must be of string type")
+
+            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") + if pattern.match(form_item["variable"]) is None: + raise ValueError("variable in user_input_form must be a string, " + "and cannot start with a number") + + variables.append(form_item["variable"]) + + if 'required' not in form_item or not form_item["required"]: + form_item["required"] = False + + if not isinstance(form_item["required"], bool): + raise ValueError("required in user_input_form must be of boolean type") + + if key == "select": + if 'options' not in form_item or not form_item["options"]: + form_item["options"] = [] + + if not isinstance(form_item["options"], list): + raise ValueError("options in user_input_form must be a list of strings") + + if "default" in form_item and form_item['default'] \ + and form_item["default"] not in form_item["options"]: + raise ValueError("default value in user_input_form must be in the options list") + + return config, ["user_input_form"] + + @classmethod + def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for external data fetch feature + + :param tenant_id: workspace id + :param config: app model config args + """ + if not config.get("external_data_tools"): + config["external_data_tools"] = [] + + if not isinstance(config["external_data_tools"], list): + raise ValueError("external_data_tools must be of list type") + + for tool in config["external_data_tools"]: + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + + if not tool["enabled"]: + continue + + if "type" not in tool or not tool["type"]: + raise ValueError("external_data_tools[].type is required") + + typ = tool["type"] + config = tool["config"] + + ExternalDataToolFactory.validate_config( + name=typ, + tenant_id=tenant_id, + config=config + ) + + return config, ["external_data_tools"] \ No newline at end of file diff --git a/api/core/entities/application_entities.py b/api/core/app/app_config/entities.py similarity index 56% rename from api/core/entities/application_entities.py rename to api/core/app/app_config/entities.py index abcf605c92d961..6a521dfcc5b7b5 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/app/app_config/entities.py @@ -1,12 +1,10 @@ from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Optional from pydantic import BaseModel -from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import AIModelEntity +from models.model import AppMode class ModelConfigEntity(BaseModel): @@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel): """ provider: str model: str - model_schema: AIModelEntity - mode: str - provider_model_bundle: ProviderModelBundle - credentials: dict[str, Any] = {} + mode: Optional[str] = None parameters: dict[str, Any] = {} stop: list[str] = [] @@ -86,6 +81,40 @@ def value_of(cls, value: str) -> 'PromptType': advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntity(BaseModel): + """ + Variable Entity. + """ + class Type(Enum): + TEXT_INPUT = 'text-input' + SELECT = 'select' + PARAGRAPH = 'paragraph' + NUMBER = 'number' + + @classmethod + def value_of(cls, value: str) -> 'VariableEntity.Type': + """ + Get value of given mode. 
diff --git a/api/core/entities/application_entities.py b/api/core/app/app_config/entities.py
similarity index 56%
rename from api/core/entities/application_entities.py
rename to api/core/app/app_config/entities.py
index abcf605c92d961..6a521dfcc5b7b5
--- a/api/core/entities/application_entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,12 +1,10 @@
 from enum import Enum
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.entities.provider_configuration import ProviderModelBundle
-from core.file.file_obj import FileObj
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import AIModelEntity
+from models.model import AppMode
 
 
 class ModelConfigEntity(BaseModel):
@@ -15,10 +13,7 @@
     """
     provider: str
     model: str
-    model_schema: AIModelEntity
-    mode: str
-    provider_model_bundle: ProviderModelBundle
-    credentials: dict[str, Any] = {}
+    mode: Optional[str] = None
     parameters: dict[str, Any] = {}
     stop: list[str] = []
 
@@ -86,6 +81,40 @@ def value_of(cls, value: str) -> 'PromptType':
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
+class VariableEntity(BaseModel):
+    """
+    Variable Entity.
+    """
+    class Type(Enum):
+        TEXT_INPUT = 'text-input'
+        SELECT = 'select'
+        PARAGRAPH = 'paragraph'
+        NUMBER = 'number'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'VariableEntity.Type':
+            """
+            Get the variable type of the given value.
+
+            :param value: type value
+            :return: variable type
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid variable type value {value}')
+
+    variable: str
+    label: str
+    description: Optional[str] = None
+    type: Type
+    required: bool = False
+    max_length: Optional[int] = None
+    options: Optional[list[str]] = None
+    default: Optional[str] = None
+    hint: Optional[str] = None
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
@@ -124,7 +153,6 @@ def value_of(cls, value: str) -> 'RetrieveStrategy':
     query_variable: Optional[str] = None  # Only when app mode is completion
 
     retrieve_strategy: RetrieveStrategy
-    single_strategy: Optional[str] = None  # for temp
     top_k: Optional[int] = None
     score_threshold: Optional[float] = None
     reranking_model: Optional[dict] = None
@@ -162,148 +190,53 @@ class FileUploadEntity(BaseModel):
     image_config: Optional[dict[str, Any]] = None
 
 
-class AgentToolEntity(BaseModel):
-    """
-    Agent Tool Entity.
-    """
-    provider_type: Literal["builtin", "api"]
-    provider_id: str
-    tool_name: str
-    tool_parameters: dict[str, Any] = {}
-
-
-class AgentPromptEntity(BaseModel):
-    """
-    Agent Prompt Entity.
-    """
-    first_prompt: str
-    next_iteration: str
-
-
-class AgentScratchpadUnit(BaseModel):
-    """
-    Agent First Prompt Entity.
-    """
-
-    class Action(BaseModel):
-        """
-        Action Entity.
-        """
-        action_name: str
-        action_input: Union[dict, str]
-
-    agent_response: Optional[str] = None
-    thought: Optional[str] = None
-    action_str: Optional[str] = None
-    observation: Optional[str] = None
-    action: Optional[Action] = None
-
-
-class AgentEntity(BaseModel):
-    """
-    Agent Entity.
-    """
-
-    class Strategy(Enum):
-        """
-        Agent Strategy.
-        """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
-
-    provider: str
-    model: str
-    strategy: Strategy
-    prompt: Optional[AgentPromptEntity] = None
-    tools: list[AgentToolEntity] = None
-    max_iteration: int = 5
-
-
-class AppOrchestrationConfigEntity(BaseModel):
-    """
-    App Orchestration Config Entity.
-    """
-    model_config: ModelConfigEntity
-    prompt_template: PromptTemplateEntity
-    external_data_variables: list[ExternalDataVariableEntity] = []
-    agent: Optional[AgentEntity] = None
-
-    # features
-    dataset: Optional[DatasetEntity] = None
+class AppAdditionalFeatures(BaseModel):
     file_upload: Optional[FileUploadEntity] = None
     opening_statement: Optional[str] = None
+    suggested_questions: list[str] = []
     suggested_questions_after_answer: bool = False
     show_retrieve_source: bool = False
     more_like_this: bool = False
     speech_to_text: bool = False
-    text_to_speech: dict = {}
-    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+    text_to_speech: Optional[TextToSpeechEntity] = None
 
 
-class InvokeFrom(Enum):
+class AppConfig(BaseModel):
     """
-    Invoke From.
+    Application Config Entity.
     """
-    SERVICE_API = 'service-api'
-    WEB_APP = 'web-app'
-    EXPLORE = 'explore'
-    DEBUGGER = 'debugger'
-
-    @classmethod
-    def value_of(cls, value: str) -> 'InvokeFrom':
-        """
-        Get value of given mode.
-
-        :param value: mode value
-        :return: mode
-        """
-        for mode in cls:
-            if mode.value == value:
-                return mode
-        raise ValueError(f'invalid invoke from value {value}')
+    tenant_id: str
+    app_id: str
+    app_mode: AppMode
+    additional_features: AppAdditionalFeatures
+    variables: list[VariableEntity] = []
+    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
 
-    def to_source(self) -> str:
-        """
-        Get source of invoke from.
- :return: source - """ - if self == InvokeFrom.WEB_APP: - return 'web_app' - elif self == InvokeFrom.DEBUGGER: - return 'dev' - elif self == InvokeFrom.EXPLORE: - return 'explore_app' - elif self == InvokeFrom.SERVICE_API: - return 'api' - - return 'dev' +class EasyUIBasedAppModelConfigFrom(Enum): + """ + App Model Config From. + """ + ARGS = 'args' + APP_LATEST_CONFIG = 'app-latest-config' + CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' -class ApplicationGenerateEntity(BaseModel): +class EasyUIBasedAppConfig(AppConfig): """ - Application Generate Entity. + Easy UI Based App Config Entity. """ - task_id: str - tenant_id: str - - app_id: str + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - # for save app_model_config_dict: dict - app_model_config_override: bool - - # Converted from app_model_config to Entity object, or directly covered by external input - app_orchestration_config_entity: AppOrchestrationConfigEntity - - conversation_id: Optional[str] = None - inputs: dict[str, str] - query: Optional[str] = None - files: list[FileObj] = [] - user_id: str - # extras - stream: bool - invoke_from: InvokeFrom - - # extra parameters, like: auto_generate_conversation_name - extras: dict[str, Any] = {} + model: ModelConfigEntity + prompt_template: PromptTemplateEntity + dataset: Optional[DatasetEntity] = None + external_data_variables: list[ExternalDataVariableEntity] = [] + + +class WorkflowUIBasedAppConfig(AppConfig): + """ + Workflow UI Based App Config Entity. + """ + workflow_id: str diff --git a/api/core/app/app_config/features/__init__.py b/api/core/app/app_config/features/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/features/file_upload/__init__.py b/api/core/app/app_config/features/file_upload/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py new file mode 100644 index 00000000000000..63830696ffd28c --- /dev/null +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -0,0 +1,61 @@ +from typing import Optional + +from core.app.app_config.entities import FileUploadEntity + + +class FileUploadConfigManager: + @classmethod + def convert(cls, config: dict) -> Optional[FileUploadEntity]: + """ + Convert model config to model config + + :param config: model config args + """ + file_upload_dict = config.get('file_upload') + if file_upload_dict: + if 'image' in file_upload_dict and file_upload_dict['image']: + if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: + return FileUploadEntity( + image_config={ + 'number_limits': file_upload_dict['image']['number_limits'], + 'detail': file_upload_dict['image']['detail'], + 'transfer_methods': file_upload_dict['image']['transfer_methods'] + } + ) + + return None + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for file upload feature + + :param config: app model config args + """ + if not config.get("file_upload"): + config["file_upload"] = {} + + if not isinstance(config["file_upload"], dict): + raise ValueError("file_upload must be of dict type") + + # check image config + if not config["file_upload"].get("image"): + config["file_upload"]["image"] = {"enabled": False} + + if config['file_upload']['image']['enabled']: + number_limits = 
config['file_upload']['image']['number_limits']
+            if number_limits < 1 or number_limits > 6:
+                raise ValueError("number_limits must be in [1, 6]")
+
+            detail = config['file_upload']['image']['detail']
+            if detail not in ['high', 'low']:
+                raise ValueError("detail must be in ['high', 'low']")
+
+            transfer_methods = config['file_upload']['image']['transfer_methods']
+            if not isinstance(transfer_methods, list):
+                raise ValueError("transfer_methods must be of list type")
+            for method in transfer_methods:
+                if method not in ['remote_url', 'local_file']:
+                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+
+        return config, ["file_upload"]
diff --git a/api/core/app/app_config/features/more_like_this/__init__.py b/api/core/app/app_config/features/more_like_this/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py
new file mode 100644
index 00000000000000..ec2a9a679611c7
--- /dev/null
+++ b/api/core/app/app_config/features/more_like_this/manager.py
@@ -0,0 +1,38 @@
+class MoreLikeThisConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to the more-like-this flag
+
+        :param config: model config args
+        """
+        more_like_this = False
+        more_like_this_dict = config.get('more_like_this')
+        if more_like_this_dict:
+            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+                more_like_this = True
+
+        return more_like_this
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for more like this feature
+
+        :param config: app model config args
+        """
+        if not config.get("more_like_this"):
+            config["more_like_this"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["more_like_this"], dict):
+            raise ValueError("more_like_this must be of dict type")
+
+        if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
+            config["more_like_this"]["enabled"] = False
+
+        if not isinstance(config["more_like_this"]["enabled"], bool):
+            raise ValueError("enabled in more_like_this must be of boolean type")
+
+        return config, ["more_like_this"]
diff --git a/api/core/app/app_config/features/opening_statement/__init__.py b/api/core/app/app_config/features/opening_statement/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py
new file mode 100644
index 00000000000000..0d8a71bfcf4d8b
--- /dev/null
+++ b/api/core/app/app_config/features/opening_statement/manager.py
@@ -0,0 +1,43 @@
+
+
+class OpeningStatementConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> tuple[str, list]:
+        """
+        Convert model config to the opening statement and suggested questions
+
+        :param config: model config args
+        """
+        # opening statement
+        opening_statement = config.get('opening_statement')
+
+        # suggested questions
+        suggested_questions_list = config.get('suggested_questions')
+
+        return opening_statement, suggested_questions_list
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for opening statement feature
+
+        :param config: app model config args
+        """
+        if not config.get("opening_statement"):
+            config["opening_statement"] = ""
+
+        if not isinstance(config["opening_statement"], str):
+            raise ValueError("opening_statement must be of string type")
+
+        # suggested_questions
+        if not config.get("suggested_questions"):
+            config["suggested_questions"] = []
+
+        if not isinstance(config["suggested_questions"], list):
+            raise ValueError("suggested_questions must be of list type")
+
+        for question in config["suggested_questions"]:
+            if not isinstance(question, str):
+                raise ValueError("Elements in suggested_questions list must be of string type")
+
+        return config, ["opening_statement", "suggested_questions"]
diff --git a/api/core/app/app_config/features/retrieval_resource/__init__.py b/api/core/app/app_config/features/retrieval_resource/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py
new file mode 100644
index 00000000000000..0694cb954e47e9
--- /dev/null
+++ b/api/core/app/app_config/features/retrieval_resource/manager.py
@@ -0,0 +1,33 @@
+class RetrievalResourceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        show_retrieve_source = False
+        retriever_resource_dict = config.get('retriever_resource')
+        if retriever_resource_dict:
+            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+                show_retrieve_source = True
+
+        return show_retrieve_source
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for retriever resource feature
+
+        :param config: app model config args
+        """
+        if not config.get("retriever_resource"):
+            config["retriever_resource"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["retriever_resource"], dict):
+            raise ValueError("retriever_resource must be of dict type")
+
+        if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
+            config["retriever_resource"]["enabled"] = False
+
+        if not isinstance(config["retriever_resource"]["enabled"], bool):
+            raise ValueError("enabled in retriever_resource must be of boolean type")
+
+        return config, ["retriever_resource"]
diff --git a/api/core/app/app_config/features/speech_to_text/__init__.py b/api/core/app/app_config/features/speech_to_text/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py
new file mode 100644
index 00000000000000..b98699bfffdc87
--- /dev/null
+++ b/api/core/app/app_config/features/speech_to_text/manager.py
@@ -0,0 +1,38 @@
+class SpeechToTextConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to the speech-to-text flag
+
+        :param config: model config args
+        """
+        speech_to_text = False
+        speech_to_text_dict = config.get('speech_to_text')
+        if speech_to_text_dict:
+            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+                speech_to_text = True
+
+        return speech_to_text
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for speech to text feature
+
+        :param config: app model config args
+        """
+        if not config.get("speech_to_text"):
+            config["speech_to_text"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["speech_to_text"], dict):
+            raise ValueError("speech_to_text must be of dict type")
+
+        if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
config["speech_to_text"]["enabled"] = False + + if not isinstance(config["speech_to_text"]["enabled"], bool): + raise ValueError("enabled in speech_to_text must be of boolean type") + + return config, ["speech_to_text"] diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py b/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py new file mode 100644 index 00000000000000..5aacd3b32d3e52 --- /dev/null +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -0,0 +1,39 @@ +class SuggestedQuestionsAfterAnswerConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + """ + Convert model config to model config + + :param config: model config args + """ + suggested_questions_after_answer = False + suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') + if suggested_questions_after_answer_dict: + if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: + suggested_questions_after_answer = True + + return suggested_questions_after_answer + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for suggested questions feature + + :param config: app model config args + """ + if not config.get("suggested_questions_after_answer"): + config["suggested_questions_after_answer"] = { + "enabled": False + } + + if not isinstance(config["suggested_questions_after_answer"], dict): + raise ValueError("suggested_questions_after_answer must be of dict type") + + if "enabled" not in config["suggested_questions_after_answer"] or not \ + config["suggested_questions_after_answer"]["enabled"]: + config["suggested_questions_after_answer"]["enabled"] = False + + if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): + raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") + + return config, ["suggested_questions_after_answer"] diff --git a/api/core/app/app_config/features/text_to_speech/__init__.py b/api/core/app/app_config/features/text_to_speech/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py new file mode 100644 index 00000000000000..1ff31034ad48e8 --- /dev/null +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -0,0 +1,49 @@ +from core.app.app_config.entities import TextToSpeechEntity + + +class TextToSpeechConfigManager: + @classmethod + def convert(cls, config: dict) -> bool: + """ + Convert model config to model config + + :param config: model config args + """ + text_to_speech = False + text_to_speech_dict = config.get('text_to_speech') + if text_to_speech_dict: + if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: + text_to_speech = TextToSpeechEntity( + enabled=text_to_speech_dict.get('enabled'), + voice=text_to_speech_dict.get('voice'), + language=text_to_speech_dict.get('language'), + ) + + return text_to_speech + + @classmethod + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for text to speech feature + + :param config: 
diff --git a/api/core/app/app_config/features/text_to_speech/__init__.py b/api/core/app/app_config/features/text_to_speech/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py
new file mode 100644
index 00000000000000..1ff31034ad48e8
--- /dev/null
+++ b/api/core/app/app_config/features/text_to_speech/manager.py
@@ -0,0 +1,49 @@
+from typing import Optional
+
+from core.app.app_config.entities import TextToSpeechEntity
+
+
+class TextToSpeechConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[TextToSpeechEntity]:
+        """
+        Convert model config to a text-to-speech entity
+
+        :param config: model config args
+        """
+        text_to_speech = None
+        text_to_speech_dict = config.get('text_to_speech')
+        if text_to_speech_dict:
+            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
+                text_to_speech = TextToSpeechEntity(
+                    enabled=text_to_speech_dict.get('enabled'),
+                    voice=text_to_speech_dict.get('voice'),
+                    language=text_to_speech_dict.get('language'),
+                )
+
+        return text_to_speech
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for text to speech feature
+
+        :param config: app model config args
+        """
+        if not config.get("text_to_speech"):
+            config["text_to_speech"] = {
+                "enabled": False,
+                "voice": "",
+                "language": ""
+            }
+
+        if not isinstance(config["text_to_speech"], dict):
+            raise ValueError("text_to_speech must be of dict type")
+
+        if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
+            config["text_to_speech"]["enabled"] = False
+            config["text_to_speech"]["voice"] = ""
+            config["text_to_speech"]["language"] = ""
+
+        if not isinstance(config["text_to_speech"]["enabled"], bool):
+            raise ValueError("enabled in text_to_speech must be of boolean type")
+
+        return config, ["text_to_speech"]
diff --git a/api/core/app/app_config/workflow_ui_based_app/__init__.py b/api/core/app/app_config/workflow_ui_based_app/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py b/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
new file mode 100644
index 00000000000000..4b117d87f8c157
--- /dev/null
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -0,0 +1,22 @@
+from core.app.app_config.entities import VariableEntity
+from models.workflow import Workflow
+
+
+class WorkflowVariablesConfigManager:
+    @classmethod
+    def convert(cls, workflow: Workflow) -> list[VariableEntity]:
+        """
+        Convert the workflow start node's user input form to variable entities
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # find start node
+        user_input_form = workflow.user_input_form()
+
+        # variables
+        for variable in user_input_form:
+            variables.append(VariableEntity(**variable))
+
+        return variables
diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
new file mode 100644
index 00000000000000..856690dc57c6c9
--- /dev/null
+++ b/api/core/app/apps/README.md
@@ -0,0 +1,48 @@
+## Guidelines for Database Connection Management in App Runner and Task Pipeline
+
+App Runner contains long-running tasks such as LLM generation and external requests, while Flask-SQLAlchemy's connection pooling strategy allocates one connection (transaction) per request. A connection therefore stays occupied even while no database work is being done, and under high concurrency several long-running requests can exhaust the pool, leaving new requests unable to acquire a connection.
+
+Therefore, database operations in App Runner and Task Pipeline must close their connections immediately after use, and it is better to pass IDs rather than Model objects between steps to avoid detached-instance errors (SQLAlchemy's `DetachedInstanceError`).
+
+Examples:
+
+1. Creating a new record:
+
+   ```python
+   app = App(id=1)
+   db.session.add(app)
+   db.session.commit()
+   db.session.refresh(app)  # Load table defaults such as created_at into the app object; the cached values stay readable after the session is closed.
+
+   # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
+
+   db.session.close()
+
+   return app.id
+   ```
+
+2. Fetching a record from the table:
+
+   ```python
+   app = db.session.query(App).filter(App.id == app_id).first()
+
+   created_at = app.created_at
+
+   db.session.close()
+
+   # Handle tasks (including long-running ones).
+
+   ```
+
+3. Updating a table field:
+
+   ```python
+   app = db.session.query(App).filter(App.id == app_id).first()
+
+   app.updated_at = datetime.utcnow()
+   db.session.commit()
+   db.session.close()
+
+   return app_id
+   ```
+
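The threaded generate workers added later in this PR (for example `_generate_worker` in `app_generator.py`) are the motivating case for these guidelines: only IDs cross the thread boundary, and the worker re-queries inside its own session. A minimal sketch of that pattern, assuming a hypothetical `run_long_task`:

```python
import threading

from flask import Flask

from extensions.ext_database import db
from models.model import App


def start_worker(flask_app: Flask, app_id: str) -> None:
    # Pass only the ID into the thread, never a Model instance bound to the
    # request's session, which would raise DetachedInstanceError on access.
    threading.Thread(target=_worker, args=(flask_app, app_id)).start()


def _worker(flask_app: Flask, app_id: str) -> None:
    with flask_app.app_context():
        app = db.session.query(App).filter(App.id == app_id).first()
        name = app.name        # copy what is needed into plain values
        db.session.close()     # release the connection before the slow part
        run_long_task(name)    # hypothetical long-running work (LLM call, etc.)
```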
+ """ + pass + + +class AdvancedChatAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, + workflow: Workflow) -> AdvancedChatAppConfig: + features_dict = workflow.features_dict + + app_config = AdvancedChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=features_dict + ), + variables=WorkflowVariablesConfigManager.convert( + workflow=workflow + ), + additional_features=cls.convert_features(features_dict) + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: if True, only structure validation will be performed + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py new file mode 100644 index 00000000000000..1a33a3230bd773 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -0,0 +1,218 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from 
core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + 
workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except GenerateTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
+ :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param user: account or end user + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedException() + else: + logger.exception(e) + raise e diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py new file mode 100644 index 00000000000000..5f5fd7010c1c42 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -0,0 +1,210 @@ +import logging +import time +from typing import Optional, cast + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent +from core.moderation.base import ModerationException +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.model import App, Conversation, Message +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppRunner(AppRunner): + """ + AdvancedChat Application Runner + """ + + def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(AdvancedChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity + ): + return + + db.session.close() + + # RUN WORKFLOW + workflow_engine_manager = 
WorkflowEngineManager() + workflow_engine_manager.run_workflow( + workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, + user_inputs=inputs, + system_inputs={ + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION: conversation.id, + }, + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: dict, + query: str) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :param query: query + :return: + """ + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_generate_entity.app_config.tenant_id, + app_generate_entity=app_generate_entity, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self._stream_output( + queue_manager=queue_manager, + text=str(e), + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION + ) + return True + + return False + + def handle_annotation_reply(self, app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + """ + Handle annotation reply + :param app_record: app record + :param message: message + :param query: query + :param queue_manager: application queue manager + :param app_generate_entity: application generate entity + """ + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=app_generate_entity.user_id, + invoke_from=app_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER + ) + + self._stream_output( + queue_manager=queue_manager, + text=annotation_reply.content, + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + ) + return True + + return False + + def _stream_output(self, queue_manager: AppQueueManager, + text: str, + stream: bool, + stopped_by: QueueStopEvent.StopBy) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :param stream: stream + :return: + """ + if stream: + index = 0 + for token in text: + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) + index += 1 + time.sleep(0.01) + + queue_manager.publish( + QueueStopEvent(stopped_by=stopped_by), + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py new file mode 100644 index 00000000000000..ca4b143027c512 --- /dev/null 
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -0,0 +1,987 @@
+import json
+import logging
+import time
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+from pydantic import BaseModel, Extra
+
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
+from core.app.entities.queue_entities import (
+    QueueAdvancedChatMessageEndEvent,
+    QueueAnnotationReplyEvent,
+    QueueErrorEvent,
+    QueueMessageFileEvent,
+    QueueMessageReplaceEvent,
+    QueueNodeFailedEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+    QueuePingEvent,
+    QueueRetrieverResourcesEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFailedEvent,
+    QueueWorkflowStartedEvent,
+    QueueWorkflowSucceededEvent,
+)
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.moderation.output_moderation import ModerationRule, OutputModeration
+from core.tools.tool_file_manager import ToolFileManager
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from models.account import Account
+from models.model import Conversation, EndUser, Message, MessageFile
+from models.workflow import (
+    Workflow,
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+    WorkflowRun,
+    WorkflowRunStatus,
+    WorkflowRunTriggeredFrom,
+)
+from services.annotation_service import AppAnnotationService
+
+logger = logging.getLogger(__name__)
+
+
+class StreamGenerateRoute(BaseModel):
+    """
+    StreamGenerateRoute entity
+    """
+    answer_node_id: str
+    generate_route: list[GenerateRouteChunk]
+    current_route_position: int = 0
+
+
+class TaskState(BaseModel):
+    """
+    TaskState entity
+    """
+
+    class NodeExecutionInfo(BaseModel):
+        """
+        NodeExecutionInfo entity
+        """
+        workflow_node_execution_id: str
+        node_type: NodeType
+        start_at: float
+
+        class Config:
+            """Configuration for this pydantic object."""
+
+            extra = Extra.forbid
+            arbitrary_types_allowed = True
+
+    answer: str = ""
+    metadata: dict = {}
+    usage: LLMUsage
+
+    workflow_run_id: Optional[str] = None
+    start_at: Optional[float] = None
+    total_tokens: int = 0
+    total_steps: int = 0
+
+    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    latest_node_execution_info: Optional[NodeExecutionInfo] = None
+
+    current_stream_generate_state: Optional[StreamGenerateRoute] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+
+class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
+    """
+    AdvancedChatAppGenerateTaskPipeline generates stream output and manages state for the application.
+ """ + + def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param user: user + :param stream: stream + """ + self._application_generate_entity = application_generate_entity + self._workflow = workflow + self._queue_manager = queue_manager + self._conversation = conversation + self._message = message + self._user = user + self._task_state = TaskState( + usage=LLMUsage.empty_usage() + ) + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + self._stream = stream + + if stream: + self._stream_generate_routes = self._get_stream_generate_routes() + else: + self._stream_generate_routes = None + + def process(self) -> Union[dict, Generator]: + """ + Process generate task pipeline. + :return: + """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + + if self._stream: + return self._process_stream_response() + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> dict: + """ + Process blocking response. + :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + raise self._handle_error(event) + elif isinstance(event, QueueRetrieverResourcesEvent): + self._task_state.metadata['retriever_resources'] = event.retriever_resources + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + if annotation: + account = annotation.account + self._task_state.metadata['annotation_reply'] = { + 'id': annotation.id, + 'account': { + 'id': annotation.account_id, + 'name': account.name if account else 'Dify user' + } + } + + self._task_state.answer = annotation.content + elif isinstance(event, QueueWorkflowStartedEvent): + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + workflow_run = self._on_workflow_finished(event) + + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + # Save message + self._save_message() + + response = { + 'event': 'message', + 'task_id': self._application_generate_entity.task_id, + 'id': self._message.id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'mode': self._conversation.mode, + 'answer': self._task_state.answer, + 'metadata': {}, + 'created_at': int(self._message.created_at.timestamp()) + } + + if self._task_state.metadata: + response['metadata'] = self._get_response_metadata() + + return response + else: + 
continue + + def _process_stream_response(self) -> Generator: + """ + Process stream response. + :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + data = self._error_to_stream_response_data(self._handle_error(event)) + yield self._yield_response(data) + break + elif isinstance(event, QueueWorkflowStartedEvent): + workflow_run = self._on_workflow_start() + + response = { + 'event': 'workflow_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_run.id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'sequence_number': workflow_run.sequence_number, + 'created_at': int(workflow_run.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeStartedEvent): + workflow_node_execution = self._on_node_start(event) + + response = { + 'event': 'node_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + workflow_node_execution = self._on_node_finished(event) + + if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: + if workflow_node_execution.node_type == NodeType.LLM.value: + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + + response = { + 'event': 'node_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'process_data': workflow_node_execution.process_data_dict, + 'outputs': workflow_node_execution.outputs_dict, + 'status': workflow_node_execution.status, + 'error': workflow_node_execution.error, + 'elapsed_time': workflow_node_execution.elapsed_time, + 'execution_metadata': workflow_node_execution.execution_metadata_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()), + 'finished_at': int(workflow_node_execution.finished_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + workflow_run = self._on_workflow_finished(event) + + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_run.id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': 
workflow_run.workflow_id,
+                        'status': workflow_run.status,
+                        'outputs': workflow_run.outputs_dict,
+                        'error': workflow_run.error,
+                        'elapsed_time': workflow_run.elapsed_time,
+                        'total_tokens': workflow_run.total_tokens,
+                        'total_steps': workflow_run.total_steps,
+                        'created_at': int(workflow_run.created_at.timestamp()),
+                        'finished_at': int(workflow_run.finished_at.timestamp())
+                    }
+                }
+
+                yield self._yield_response(workflow_run_response)
+            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
+                # response moderation
+                if self._output_moderation_handler:
+                    self._output_moderation_handler.stop_thread()
+
+                    self._task_state.answer = self._output_moderation_handler.moderation_completion(
+                        completion=self._task_state.answer,
+                        public_event=False
+                    )
+
+                    self._output_moderation_handler = None
+
+                    replace_response = {
+                        'event': 'message_replace',
+                        'task_id': self._application_generate_entity.task_id,
+                        'message_id': self._message.id,
+                        'conversation_id': self._conversation.id,
+                        'answer': self._task_state.answer,
+                        'created_at': int(self._message.created_at.timestamp())
+                    }
+
+                    yield self._yield_response(replace_response)
+
+                # Save message
+                self._save_message()
+
+                response = {
+                    'event': 'message_end',
+                    'task_id': self._application_generate_entity.task_id,
+                    'id': self._message.id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                }
+
+                if self._task_state.metadata:
+                    response['metadata'] = self._get_response_metadata()
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueueRetrieverResourcesEvent):
+                self._task_state.metadata['retriever_resources'] = event.retriever_resources
+            elif isinstance(event, QueueAnnotationReplyEvent):
+                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
+                if annotation:
+                    account = annotation.account
+                    self._task_state.metadata['annotation_reply'] = {
+                        'id': annotation.id,
+                        'account': {
+                            'id': annotation.account_id,
+                            'name': account.name if account else 'Dify user'
+                        }
+                    }
+
+                    self._task_state.answer = annotation.content
+            elif isinstance(event, QueueMessageFileEvent):
+                message_file: MessageFile = (
+                    db.session.query(MessageFile)
+                    .filter(MessageFile.id == event.message_file_id)
+                    .first()
+                )
+
+                if message_file:
+                    # get extension
+                    if '.' in message_file.url:
+                        extension = f'.{message_file.url.split(".")[-1]}'
+                        if len(extension) > 10:
+                            extension = '.bin'
+                    else:
+                        extension = '.bin'
+
+                    # add sign url
+                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)
+
+                    response = {
+                        'event': 'message_file',
+                        'conversation_id': self._conversation.id,
+                        'id': message_file.id,
+                        'type': message_file.type,
+                        'belongs_to': message_file.belongs_to or 'user',
+                        'url': url
+                    }
+
+                    yield self._yield_response(response)
+            elif isinstance(event, QueueTextChunkEvent):
+                if not self._is_stream_out_support(
+                        event=event
+                ):
+                    continue
+
+                delta_text = event.text
+                if delta_text is None:
+                    continue
+
+                if self._output_moderation_handler:
+                    if self._output_moderation_handler.should_direct_output():
+                        # stop subscribing to new tokens when output moderation should direct output
+                        self._task_state.answer = self._output_moderation_handler.get_final_output()
+                        self._queue_manager.publish(
+                            QueueTextChunkEvent(
+                                text=self._task_state.answer
+                            ), PublishFrom.TASK_PIPELINE
+                        )
+
+                        self._queue_manager.publish(
+                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                            PublishFrom.TASK_PIPELINE
+                        )
+                        continue
+                    else:
+                        self._output_moderation_handler.append_new_token(delta_text)
+
+                self._task_state.answer += delta_text
+                response = self._handle_chunk(delta_text)
+                yield self._yield_response(response)
+            elif isinstance(event, QueueMessageReplaceEvent):
+                response = {
+                    'event': 'message_replace',
+                    'task_id': self._application_generate_entity.task_id,
+                    'message_id': self._message.id,
+                    'conversation_id': self._conversation.id,
+                    'answer': event.text,
+                    'created_at': int(self._message.created_at.timestamp())
+                }
+
+                yield self._yield_response(response)
+            elif isinstance(event, QueuePingEvent):
+                yield "event: ping\n\n"
+            else:
+                continue
+
+    def _on_workflow_start(self) -> WorkflowRun:
+        self._task_state.start_at = time.perf_counter()
+
+        workflow_run = self._init_workflow_run(
+            workflow=self._workflow,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
+            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
+            else WorkflowRunTriggeredFrom.APP_RUN,
+            user=self._user,
+            user_inputs=self._application_generate_entity.inputs,
+            system_inputs={
+                SystemVariable.QUERY: self._message.query,
+                SystemVariable.FILES: self._application_generate_entity.files,
+                SystemVariable.CONVERSATION: self._conversation.id,
+            }
+        )
+
+        self._task_state.workflow_run_id = workflow_run.id
+
+        db.session.close()
+
+        return workflow_run
+
+    def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
+        workflow_node_execution = self._init_node_execution_from_workflow_run(
+            workflow_run=workflow_run,
+            node_id=event.node_id,
+            node_type=event.node_type,
+            node_title=event.node_data.title,
+            node_run_index=event.node_run_index,
+            predecessor_node_id=event.predecessor_node_id
+        )
+
+        latest_node_execution_info = TaskState.NodeExecutionInfo(
+            workflow_node_execution_id=workflow_node_execution.id,
+            node_type=event.node_type,
+            start_at=time.perf_counter()
+        )
+
+        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.latest_node_execution_info = latest_node_execution_info
+
+        self._task_state.total_steps += 1
+
+        db.session.close()
+
+        # search stream_generate_routes if the node is the start node of an answer stream route
+        if not 
self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: + self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] + + # stream outputs from start + self._generate_stream_outputs_when_node_start() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: + current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if workflow_node_execution.node_type == NodeType.LLM.value: + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + error=event.error + ) + + # stream outputs when node finished + self._generate_stream_outputs_when_node_finished() + + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = (db.session.query(WorkflowRun) + .filter(WorkflowRun.id == self._task_state.workflow_run_id).first()) + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + + workflow_run = self._workflow_run_success( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=outputs + ) + + self._task_state.workflow_run_id = workflow_run.id + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + + db.session.close() + + return workflow_run + + def _save_message(self) -> None: + """ + Save message. 
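+        Persist the streamed answer, workflow run id and token usage to the Message record.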
+        :return:
+        """
+        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+
+        self._message.answer = self._task_state.answer
+        self._message.provider_response_latency = time.perf_counter() - self._start_at
+        self._message.workflow_run_id = self._task_state.workflow_run_id
+
+        if self._task_state.metadata and self._task_state.metadata.get('usage'):
+            usage = LLMUsage(**self._task_state.metadata['usage'])
+
+            self._message.message_tokens = usage.prompt_tokens
+            self._message.message_unit_price = usage.prompt_unit_price
+            self._message.message_price_unit = usage.prompt_price_unit
+            self._message.answer_tokens = usage.completion_tokens
+            self._message.answer_unit_price = usage.completion_unit_price
+            self._message.answer_price_unit = usage.completion_price_unit
+            self._message.total_price = usage.total_price
+            self._message.currency = usage.currency
+
+        db.session.commit()
+
+        message_was_created.send(
+            self._message,
+            application_generate_entity=self._application_generate_entity,
+            conversation=self._conversation,
+            is_first_message=self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras
+        )
+
+    def _handle_chunk(self, text: str) -> dict:
+        """
+        Build a stream response for a message text chunk.
+        :param text: text
+        :return:
+        """
+        response = {
+            'event': 'message',
+            'id': self._message.id,
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            'conversation_id': self._conversation.id,
+            'answer': text,
+            'created_at': int(self._message.created_at.timestamp())
+        }
+
+        return response
+
+    def _handle_error(self, event: QueueErrorEvent) -> Exception:
+        """
+        Handle error event.
+        :param event: event
+        :return:
+        """
+        logger.debug("error: %s", event.error)
+        e = event.error
+
+        if isinstance(e, InvokeAuthorizationError):
+            return InvokeAuthorizationError('Incorrect API key provided')
+        elif isinstance(e, InvokeError | ValueError):
+            return e
+        else:
+            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+
+    def _error_to_stream_response_data(self, e: Exception) -> dict:
+        """
+        Error to stream response.
+        :param e: exception
+        :return:
+        """
+        error_responses = {
+            ValueError: {'code': 'invalid_param', 'status': 400},
+            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
+            QuotaExceededError: {
+                'code': 'provider_quota_exceeded',
+                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
+                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                'status': 400
+            },
+            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
+            InvokeError: {'code': 'completion_request_error', 'status': 400}
+        }
+
+        # Determine the response based on the type of exception
+        data = None
+        for k, v in error_responses.items():
+            if isinstance(e, k):
+                # copy the template so setdefault below does not mutate the shared mapping
+                data = v.copy()
+
+        if data:
+            data.setdefault('message', getattr(e, 'description', str(e)))
+        else:
+            logger.error(e)
+            data = {
+                'code': 'internal_server_error',
+                'message': 'Internal Server Error, please contact support.',
+                'status': 500
+            }
+
+        return {
+            'event': 'error',
+            'task_id': self._application_generate_entity.task_id,
+            'message_id': self._message.id,
+            **data
+        }
+
+    def _get_response_metadata(self) -> dict:
+        """
+        Get response metadata by invoke from.
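+        Usage and annotation replies are only exposed to DEBUGGER and SERVICE_API callers; other callers receive a trimmed retriever resource list.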
+ :return: + """ + metadata = {} + + # show_retrieve_source + if 'retriever_resources' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] + else: + metadata['retriever_resources'] = [] + for resource in self._task_state.metadata['retriever_resources']: + metadata['retriever_resources'].append({ + 'segment_id': resource['segment_id'], + 'position': resource['position'], + 'document_name': resource['document_name'], + 'score': resource['score'], + 'content': resource['content'], + }) + # show annotation reply + if 'annotation_reply' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] + + # show usage + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['usage'] = self._task_state.metadata['usage'] + + return metadata + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. + :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + queue_manager=self._queue_manager + ) + + def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]: + """ + Get stream generate routes. + :return: + """ + # find all answer nodes + graph = self._workflow.graph_dict + answer_node_configs = [ + node for node in graph['nodes'] + if node.get('data', {}).get('type') == NodeType.ANSWER.value + ] + + # parse stream output node value selectors of answer nodes + stream_generate_routes = {} + for node_config in answer_node_configs: + # get generate route for stream output + answer_node_id = node_config['id'] + generate_route = AnswerNode.extract_generate_route_selectors(node_config) + start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id) + if not start_node_id: + continue + + stream_generate_routes[start_node_id] = StreamGenerateRoute( + answer_node_id=answer_node_id, + generate_route=generate_route + ) + + return stream_generate_routes + + def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \ + -> Optional[str]: + """ + Get answer start at node id. 
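+        Walk the graph backwards from the target node to find the node after which answer streaming can begin.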
+        :param graph: graph
+        :param target_node_id: target node ID
+        :return:
+        """
+        nodes = graph.get('nodes')
+        edges = graph.get('edges')
+
+        # fetch the first ingoing edge of the target node
+        ingoing_edge = None
+        for edge in edges:
+            if edge.get('target') == target_node_id:
+                ingoing_edge = edge
+                break
+
+        if not ingoing_edge:
+            return None
+
+        source_node_id = ingoing_edge.get('source')
+        source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
+        if not source_node:
+            return None
+
+        node_type = source_node.get('data', {}).get('type')
+        if node_type in [
+            NodeType.ANSWER.value,
+            NodeType.IF_ELSE.value,
+            NodeType.QUESTION_CLASSIFIER.value
+        ]:
+            start_node_id = target_node_id
+        elif node_type == NodeType.START.value:
+            start_node_id = source_node_id
+        else:
+            start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
+
+        return start_node_id
+
+    def _generate_stream_outputs_when_node_start(self) -> None:
+        """
+        Generate stream outputs.
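+        Emit the leading text chunks of the current stream generate route token by token, stopping at the first non-text chunk.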
+        :return:
+        """
+        if not self._task_state.current_stream_generate_state:
+            return
+
+        route_chunks = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position:]
+
+        for route_chunk in route_chunks:
+            if route_chunk.type == 'text':
+                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                for token in route_chunk.text:
+                    self._queue_manager.publish(
+                        QueueTextChunkEvent(
+                            text=token
+                        ), PublishFrom.TASK_PIPELINE
+                    )
+                    time.sleep(0.01)
+            else:
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+                route_chunk_node_id = value_selector[0]
+
+                # the referenced node has not run yet; stop and resume when the next node finishes
+                if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
+                    break
+
+                latest_node_execution_info = self._task_state.latest_node_execution_info
+
+                # get route chunk node execution info
+                route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
+                if (route_chunk_node_execution_info.node_type == NodeType.LLM
+                        and latest_node_execution_info.node_type == NodeType.LLM):
+                    # LLM outputs are already streamed out chunk by chunk, skip re-emitting them here
+                    self._task_state.current_stream_generate_state.current_route_position += 1
+                    continue
+
+                # get route chunk node execution
+                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
+
+                outputs = route_chunk_node_execution.outputs_dict
+
+                # get value from outputs
+                value = None
+                for key in value_selector[1:]:
+                    if not value:
+                        value = outputs.get(key)
+                    else:
+                        value = value.get(key)
+
+                if value:
+                    text = None
+                    if isinstance(value, str | int | float):
+                        text = str(value)
+                    elif isinstance(value, dict):  # TODO FILE
+                        # convert file to markdown
+                        text = f'![]({value.get("url")})'
+
+                    if text:
+                        for token in text:
+                            self._queue_manager.publish(
+                                QueueTextChunkEvent(
+                                    text=token
+                                ), PublishFrom.TASK_PIPELINE
+                            )
+                            time.sleep(0.01)
+
+            self._task_state.current_stream_generate_state.current_route_position += 1
+
+        # all route chunks are generated
+        if self._task_state.current_stream_generate_state.current_route_position == len(
+                self._task_state.current_stream_generate_state.generate_route):
+            self._task_state.current_stream_generate_state = None
+
+    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
+        """
+        Check whether a text chunk event can be streamed out directly.
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.metadata:
+            return True
+
+        if 'node_id' not in event.metadata:
+            return True
+
+        node_type = event.metadata.get('node_type')
+        stream_output_value_selector = event.metadata.get('value_selector')
+        if not stream_output_value_selector:
+            return False
+
+        if not self._task_state.current_stream_generate_state:
+            return False
+
+        route_chunk = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position]
+
+        if route_chunk.type != 'var':
+            return False
+
+        if node_type != NodeType.LLM:
+            # only LLM supports chunk stream output
+            return False
+
+        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+        value_selector = route_chunk.value_selector
+
+        # only the chunk matching the current route position may be streamed out
+        if value_selector != stream_output_value_selector:
+            return False
+
+        return True
diff --git
a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
new file mode 100644
index 00000000000000..972fda2d49a66c
--- /dev/null
+++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
@@ -0,0 +1,128 @@
+from typing import Optional
+
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.queue_entities import (
+    QueueNodeFailedEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFailedEvent,
+    QueueWorkflowStartedEvent,
+    QueueWorkflowSucceededEvent,
+)
+from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
+from models.workflow import Workflow
+
+
+class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+
+    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
+        self._queue_manager = queue_manager
+
+    def on_workflow_run_started(self) -> None:
+        """
+        Workflow run started
+        """
+        self._queue_manager.publish(
+            QueueWorkflowStartedEvent(),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_run_succeeded(self) -> None:
+        """
+        Workflow run succeeded
+        """
+        self._queue_manager.publish(
+            QueueWorkflowSucceededEvent(),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_run_failed(self, error: str) -> None:
+        """
+        Workflow run failed
+        """
+        self._queue_manager.publish(
+            QueueWorkflowFailedEvent(
+                error=error
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_node_execute_started(self, node_id: str,
+                                         node_type: NodeType,
+                                         node_data: BaseNodeData,
+                                         node_run_index: int = 1,
+                                         predecessor_node_id: Optional[str] = None) -> None:
+        """
+        Workflow node execute started
+        """
+        self._queue_manager.publish(
+            QueueNodeStartedEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_data=node_data,
+                node_run_index=node_run_index,
+                predecessor_node_id=predecessor_node_id
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_node_execute_succeeded(self, node_id: str,
+                                           node_type: NodeType,
+                                           node_data: BaseNodeData,
+                                           inputs: Optional[dict] = None,
+                                           process_data: Optional[dict] = None,
+                                           outputs: Optional[dict] = None,
+                                           execution_metadata: Optional[dict] = None) -> None:
+        """
+        Workflow node execute succeeded
+        """
+        self._queue_manager.publish(
+            QueueNodeSucceededEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_data=node_data,
+                inputs=inputs,
+                process_data=process_data,
+                outputs=outputs,
+                execution_metadata=execution_metadata
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_workflow_node_execute_failed(self, node_id: str,
+                                        node_type: NodeType,
+                                        node_data: BaseNodeData,
+                                        error: str,
+                                        inputs: Optional[dict] = None,
+                                        process_data: Optional[dict] = None) -> None:
+        """
+        Workflow node execute failed
+        """
+        self._queue_manager.publish(
+            QueueNodeFailedEvent(
+                node_id=node_id,
+                node_type=node_type,
+                node_data=node_data,
+                inputs=inputs,
+                process_data=process_data,
+                error=error
+            ),
+            PublishFrom.APPLICATION_MANAGER
+        )
+
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
+        """
+        Publish text chunk
+        """
+        self._queue_manager.publish(
+            QueueTextChunkEvent(
+                text=text,
+                metadata={
+                    "node_id": node_id,
+                    **(metadata or {})  # metadata may be None
+                }
+            ), PublishFrom.APPLICATION_MANAGER
+        )
diff --git a/api/core/app/apps/agent_chat/__init__.py
b/api/core/app/apps/agent_chat/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py new file mode 100644 index 00000000000000..232211c18b9843 --- /dev/null +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -0,0 +1,229 @@ +import uuid +from typing import Optional + +from core.agent.entities import AgentEntity +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import App, AppMode, AppModelConfig, Conversation + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentChatAppConfig(EasyUIBasedAppConfig): + """ + Agent Chatbot App Config Entity. 
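+    Extends EasyUIBasedAppConfig with an optional agent configuration entity.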
+ """ + agent: Optional[AgentEntity] = None + + +class AgentChatAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + """ + Convert app model config to agent chat app config + :param app_model: app model + :param app_model_config: app model config + :param conversation: conversation + :param override_config_dict: app model config dict + :return: + """ + if override_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + elif conversation: + config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + else: + config_dict = override_config_dict + + app_config = AgentChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + agent=AgentConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + 
related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + + @classmethod + def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate agent_mode and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in + list(PlanningStrategy.__members__.values())]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py new file mode 100644 index 
00000000000000..cc9b0785f56106 --- /dev/null +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -0,0 +1,196 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser + +logger = logging.getLogger(__name__) + + +class AgentChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = AgentChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = 
AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AgentChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except GenerateTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app/apps/agent_chat/app_runner.py similarity index 77% rename from api/core/app_runner/assistant_app_runner.py rename to api/core/app/apps/agent_chat/app_runner.py index 655a5a1c7c811d..0dc8a1e2184abe 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,11 +1,14 @@ import logging from typing import cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import AgentEntity, 
ApplicationGenerateEntity, ModelConfigEntity -from core.features.assistant_cot_runner import AssistantCotApplicationRunner -from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentEntity +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -19,12 +22,13 @@ logger = logging.getLogger(__name__) -class AssistantApplicationRunner(AppRunner): + +class AgentChatAppRunner(AppRunner): """ - Assistant Application Runner + Agent Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + def run(self, application_generate_entity: AgentChatAppGenerateEntity, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -35,12 +39,13 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(AgentChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -52,8 +57,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -63,22 +68,22 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( conversation=conversation, model_instance=model_instance ) - + # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) # memory(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + 
model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -90,15 +95,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -116,13 +121,14 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -130,7 +136,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -145,8 +151,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional), external data, dataset context(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -163,25 +169,25 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if hosting_moderation_result: return - agent_entity = app_orchestration_config.agent + agent_entity = app_config.agent # load tool variables tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, user_id=application_generate_entity.user_id, - tenant_id=application_generate_entity.tenant_id) + tenant_id=app_config.tenant_id) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) # init model instance model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + 
prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -195,17 +201,17 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - db.session.refresh(conversation) - db.session.refresh(message) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + message = db.session.query(Message).filter(Message.id == message.id).first() db.session.close() # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = AssistantCotApplicationRunner( - tenant_id=application_generate_entity.tenant_id, + assistant_cot_runner = CotAgentRunner( + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -223,11 +229,11 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = AssistantFunctionCallApplicationRunner( - tenant_id=application_generate_entity.tenant_id, + assistant_fc_runner = FunctionCallAgentRunner( + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -288,7 +294,7 @@ def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversation 'pool': db_variables.variables }) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, + def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, message: Message) -> LLMUsage: """ Get usage of all agent thoughts diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py new file mode 100644 index 00000000000000..750c6dae1036b5 --- /dev/null +++ b/api/core/app/apps/base_app_generator.py @@ -0,0 +1,42 @@ +from core.app.app_config.entities import AppConfig, VariableEntity + + +class BaseAppGenerator: + def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + # Filter input variables from form configuration, handle required fields, default values, and option values + variables = app_config.variables + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"{variable} is required in input form") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise 
ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs + diff --git a/api/core/application_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py similarity index 51% rename from api/core/application_queue_manager.py rename to api/core/app/apps/base_app_queue_manager.py index 9590a1e7266f2f..43a44819f9495e 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,30 +1,20 @@ import queue import time +from abc import abstractmethod from collections.abc import Generator from enum import Enum from typing import Any from sqlalchemy.orm import DeclarativeMeta -from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import ( - AnnotationReplyEvent, +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( AppQueueEvent, - QueueAgentMessageEvent, - QueueAgentThoughtEvent, QueueErrorEvent, - QueueMessage, - QueueMessageEndEvent, - QueueMessageEvent, - QueueMessageFileEvent, - QueueMessageReplaceEvent, QueuePingEvent, - QueueRetrieverResourcesEvent, QueueStopEvent, ) -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client -from models.model import MessageAgentThought, MessageFile class PublishFrom(Enum): @@ -32,25 +22,20 @@ class PublishFrom(Enum): TASK_PIPELINE = 2 -class ApplicationQueueManager: +class AppQueueManager: def __init__(self, task_id: str, user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") self._task_id = task_id self._user_id = user_id self._invoke_from = invoke_from - self._conversation_id = str(conversation_id) - self._app_mode = app_mode - self._message_id = str(message_id) user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, + f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -84,7 +69,6 @@ def listen(self) -> Generator: QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) - self.stop_listen() if elapsed_time // 10 > last_ping_time: self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) @@ -97,89 +81,6 @@ def stop_listen(self) -> None: """ self._q.put(None) - def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish chunk message to channel - - :param chunk: chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageEvent( - chunk=chunk - ), pub_from) - - def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish agent chunk message to channel - - :param chunk: chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentMessageEvent( - chunk=chunk - ), pub_from) - - def publish_message_replace(self, text: str, pub_from: PublishFrom) 
-> None: - """ - Publish message replace - :param text: text - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageReplaceEvent( - text=text - ), pub_from) - - def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None: - """ - Publish retriever resources - :return: - """ - self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from) - - def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None: - """ - Publish annotation reply - :param message_annotation_id: message annotation id - :param pub_from: publish from - :return: - """ - self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from) - - def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None: - """ - Publish message end - :param llm_result: llm result - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from) - self.stop_listen() - - def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_agent_thought: message agent thought - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentThoughtEvent( - agent_thought_id=message_agent_thought.id - ), pub_from) - - def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_file: message file - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), pub_from) - def publish_error(self, e, pub_from: PublishFrom) -> None: """ Publish error @@ -190,7 +91,6 @@ def publish_error(self, e, pub_from: PublishFrom) -> None: self.publish(QueueErrorEvent( error=e ), pub_from) - self.stop_listen() def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -200,22 +100,17 @@ def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: :return: """ self._check_for_sqlalchemy_models(event.dict()) + self._publish(event, pub_from) - message = QueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event - ) - - self._q.put(message) - - if isinstance(event, QueueStopEvent): - self.stop_listen() - - if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise ConversationTaskStoppedException() + @abstractmethod + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + raise NotImplementedError @classmethod def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: @@ -239,7 +134,7 @@ def _is_stopped(self) -> bool: Check if task is stopped :return: """ - stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) + stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) result = redis_client.get(stopped_cache_key) if result is not None: return True @@ -278,5 +173,5 @@ def _check_for_sqlalchemy_models(self, data: Any): "that cause thread safety issues is not allowed.") -class ConversationTaskStoppedException(Exception): +class GenerateTaskStoppedException(Exception): pass diff --git a/api/core/app_runner/app_runner.py b/api/core/app/apps/base_app_runner.py similarity index 71% rename from 
api/core/app_runner/app_runner.py rename to api/core/app/apps/base_app_runner.py index f9678b372fce6b..868e9e724f4081 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -2,19 +2,18 @@ from collections.abc import Generator from typing import Optional, Union, cast -from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import ( - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - ExternalDataVariableEntity, +from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, + EasyUIBasedAppGenerateEntity, InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, + ModelConfigWithCredentialsEntity, ) -from core.features.annotation_reply import AnnotationReplyFeature -from core.features.external_data_fetch import ExternalDataFetchFeature -from core.features.hosting_moderation import HostingModerationFeature -from core.features.moderation import ModerationFeature +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature +from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -22,13 +21,16 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_transform import PromptTransform -from models.model import App, Message, MessageAnnotation +from core.moderation.input_moderation import InputModeration +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform +from models.model import App, AppMode, Message, MessageAnnotation class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -84,7 +86,7 @@ def get_pre_calculate_rest_tokens(self, app_record: App, return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigEntity, + def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance @@ -120,7 +122,7 @@ def recalc_llm_max_tokens(self, model_config: ModelConfigEntity, model_config.parameters[parameter_rule.name] = max_tokens def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, 
str], files: list[FileObj], @@ -140,12 +142,11 @@ def organize_prompt_messages(self, app_record: App, :param memory: memory :return: """ - prompt_transform = PromptTransform() - # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( - app_mode=app_record.mode, + app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, query=query if query else '', @@ -155,13 +156,39 @@ def organize_prompt_messages(self, app_record: App, model_config=model_config ) else: - prompt_messages = prompt_transform.get_advanced_prompt( - app_mode=app_record.mode, - prompt_template_entity=prompt_template_entity, + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + prompt_template = CompletionModelPromptTemplate( + text=advanced_completion_prompt_template.prompt + ) + + memory_config.role_prefix = MemoryConfig.RolePrefix( + user=advanced_completion_prompt_template.role_prefix.user, + assistant=advanced_completion_prompt_template.role_prefix.assistant + ) + else: + prompt_template = [] + for message in prompt_template_entity.advanced_chat_prompt_template.messages: + prompt_template.append(ChatModelMessage( + text=message.text, + role=message.role + )) + + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, inputs=inputs, - query=query, + query=query if query else '', files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) @@ -169,8 +196,8 @@ def organize_prompt_messages(self, app_record: App, return prompt_messages, stop - def direct_output(self, queue_manager: ApplicationQueueManager, - app_orchestration_config: AppOrchestrationConfigEntity, + def direct_output(self, queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list, text: str, stream: bool, @@ -178,7 +205,7 @@ def direct_output(self, queue_manager: ApplicationQueueManager, """ Direct output :param queue_manager: application queue manager - :param app_orchestration_config: app orchestration config + :param app_generate_entity: app generate entity :param prompt_messages: prompt messages :param text: text :param stream: stream @@ -188,29 +215,36 @@ def direct_output(self, queue_manager: ApplicationQueueManager, if stream: index = 0 for token in text: - queue_manager.publish_chunk_message(LLMResultChunk( - model=app_orchestration_config.model_config.model, + chunk = LLMResultChunk( + model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage(content=token) ) - ), PublishFrom.APPLICATION_MANAGER) + ) + + queue_manager.publish( + QueueLLMChunkEvent( + chunk=chunk + ), PublishFrom.APPLICATION_MANAGER + ) index += 1 time.sleep(0.01) - queue_manager.publish_message_end( - llm_result=LLMResult( - model=app_orchestration_config.model_config.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() - ), - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + 
QueueMessageEndEvent( + llm_result=LLMResult( + model=app_generate_entity.model_config.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=text), + usage=usage if usage else LLMUsage.empty_usage() + ), + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, stream: bool, agent: bool = False) -> None: """ @@ -234,7 +268,7 @@ def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], ) def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result direct @@ -242,13 +276,14 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, :param queue_manager: application queue manager :return: """ - queue_manager.publish_message_end( - llm_result=invoke_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=invoke_result, + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result @@ -262,9 +297,17 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage = None for result in invoke_result: if not agent: - queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueLLMChunkEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) else: - queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueAgentMessageEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) text += result.delta.message.content @@ -287,36 +330,37 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, usage=usage ) - queue_manager.publish_message_end( - llm_result=llm_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=llm_result, + ), PublishFrom.APPLICATION_MANAGER ) def moderation_for_inputs(self, app_id: str, tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, + app_generate_entity: AppGenerateEntity, inputs: dict, query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
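+        Note that moderation may revise the returned inputs and query, so
+        callers should unpack and reuse the values from the tuple, as the
+        runners below do: _, inputs, query = self.moderation_for_inputs(...)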
:param app_id: app id :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity + :param app_generate_entity: app generate entity :param inputs: inputs :param query: query :return: """ - moderation_feature = ModerationFeature() + moderation_feature = InputModeration() return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, - app_orchestration_config_entity=app_orchestration_config_entity, + app_config=app_generate_entity.app_config, inputs=inputs, - query=query, + query=query if query else '' ) - def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -334,7 +378,7 @@ def check_hosting_moderation(self, application_generate_entity: ApplicationGener if moderation_result: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=application_generate_entity.app_orchestration_config_entity, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text="I apologize for any confusion, " \ "but I'm an AI assistant to be helpful, harmless, and honest.", @@ -358,7 +402,7 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, :param query: the query :return: the filled inputs """ - external_data_fetch_feature = ExternalDataFetchFeature() + external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( tenant_id=tenant_id, app_id=app_id, @@ -388,4 +432,4 @@ def query_app_annotations_to_reply(self, app_record: App, query=query, user_id=user_id, invoke_from=invoke_from - ) \ No newline at end of file + ) diff --git a/api/core/app/apps/chat/__init__.py b/api/core/app/apps/chat/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py new file mode 100644 index 00000000000000..553cf34ee9b142 --- /dev/null +++ b/api/core/app/apps/chat/app_config_manager.py @@ -0,0 +1,147 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import 
App, AppMode, AppModelConfig, Conversation + + +class ChatAppConfig(EasyUIBasedAppConfig): + """ + Chatbot App Config Entity. + """ + pass + + +class ChatAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None) -> ChatAppConfig: + """ + Convert app model config to chat app config + :param app_model: app model + :param app_model_config: app model config + :param conversation: conversation + :param override_config_dict: app model config dict + :return: + """ + if override_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + elif conversation: + config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + else: + config_dict = override_config_dict + + app_config = ChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = 
SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py new file mode 100644 index 00000000000000..58287ba6587d09 --- /dev/null +++ b/api/core/app/apps/chat/app_generator.py @@ -0,0 +1,196 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser + +logger = logging.getLogger(__name__) + + +class ChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. 
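+
+        A minimal illustrative call, assuming an already-loaded app_model and
+        account (these names are not part of this diff):
+
+            response = ChatAppGenerator().generate(
+                app_model=app_model,
+                user=account,
+                args={'query': 'Hello', 'inputs': {}, 'conversation_id': None},
+                invoke_from=InvokeFrom.DEBUGGER,
+                stream=False
+            )
+
+        With stream=False this returns a dict; with stream=True it returns a
+        generator of server-sent event chunks.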
+ + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = ChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: 
str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except GenerateTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app/apps/chat/app_runner.py similarity index 54% rename from api/core/app_runner/basic_app_runner.py rename to api/core/app/apps/chat/app_runner.py index d3c91337c8f5c1..d51f3db5409eec 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,28 +1,31 @@ import logging -from typing import Optional - -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from typing import cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.apps.chat.app_config_manager import ChatAppConfig +from core.app.entities.app_invoke_entities import ( + ChatAppGenerateEntity, +) +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException -from core.prompt.prompt_transform import AppMode +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message logger = logging.getLogger(__name__) -class BasicApplicationRunner(AppRunner): +class ChatAppRunner(AppRunner): """ - Basic Application Runner + Chat Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + def run(self, application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -33,12 +36,13 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + 
app_config = application_generate_entity.app_config + app_config = cast(ChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -50,8 +54,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -61,8 +65,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( @@ -75,8 +79,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -88,15 +92,15 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -114,13 +118,14 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -128,7 +133,7 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: 
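+            # Illustrative example (variable names assumed, not part of this
+            # diff): an external data tool bound to a 'user_plan' variable is
+            # fetched here and merged into `inputs`, e.g.
+            #   {'name': 'Lee'} -> {'name': 'Lee', 'user_plan': 'pro'}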
inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -140,19 +145,24 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # get context from datasets context = None - if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: - context = self.retrieve_dataset_context( + if app_config.dataset and app_config.dataset.dataset_ids: + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from + ) + + dataset_retrieval = DatasetRetrieval() + context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - app_record=app_record, - queue_manager=queue_manager, - model_config=app_orchestration_config.model_config, - show_retrieve_source=app_orchestration_config.show_retrieve_source, - dataset_config=app_orchestration_config.dataset, - message=message, - inputs=inputs, + model_config=application_generate_entity.model_config, + config=app_config.dataset, query=query, - user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_config.additional_features.show_retrieve_source, + hit_callback=hit_callback, memory=memory ) @@ -161,8 +171,8 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # memory(optional), external data, dataset context(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -182,21 +192,21 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit self.recalc_llm_max_tokens( - model_config=app_orchestration_config.model_config, + model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) # Invoke model model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) db.session.close() invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=application_generate_entity.model_config.parameters, stop=stop, stream=application_generate_entity.stream, user=application_generate_entity.user_id, @@ -208,56 +218,3 @@ def run(self, application_generate_entity: ApplicationGenerateEntity, queue_manager=queue_manager, stream=application_generate_entity.stream ) - - def retrieve_dataset_context(self, tenant_id: str, - app_record: App, - queue_manager: ApplicationQueueManager, - model_config: ModelConfigEntity, - dataset_config: DatasetEntity, - show_retrieve_source: bool, - message: Message, - inputs: dict, - query: str, - user_id: str, - invoke_from: InvokeFrom, - memory: Optional[TokenBufferMemory] = None) -> Optional[str]: - """ - Retrieve dataset context - :param tenant_id: tenant id - :param app_record: app record - :param queue_manager: queue manager - :param model_config: model config - :param dataset_config: dataset 
config - :param show_retrieve_source: show retrieve source - :param message: message - :param inputs: inputs - :param query: query - :param user_id: user id - :param invoke_from: invoke from - :param memory: memory - :return: - """ - hit_callback = DatasetIndexToolCallbackHandler( - queue_manager, - app_record.id, - message.id, - user_id, - invoke_from - ) - - if (app_record.mode == AppMode.COMPLETION.value and dataset_config - and dataset_config.retrieve_config.query_variable): - query = inputs.get(dataset_config.retrieve_config.query_variable, "") - - dataset_retrieval = DatasetRetrievalFeature() - return dataset_retrieval.retrieve( - tenant_id=tenant_id, - model_config=model_config, - config=dataset_config, - query=query, - invoke_from=invoke_from, - show_retrieve_source=show_retrieve_source, - hit_callback=hit_callback, - memory=memory - ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/__init__.py b/api/core/app/apps/completion/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py new file mode 100644 index 00000000000000..b98a4c16aaf996 --- /dev/null +++ b/api/core/app/apps/completion/app_config_manager.py @@ -0,0 +1,125 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import App, AppMode, AppModelConfig + + +class CompletionAppConfig(EasyUIBasedAppConfig): + """ + Completion App Config Entity. 
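+    Currently identical to EasyUIBasedAppConfig (the body below is just
+    "pass"); it is kept as a distinct type so completion-app configs can be
+    told apart from chat-app configs when passed around.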
+ """ + pass + + +class CompletionAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, + app_model_config: AppModelConfig, + override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + """ + Convert app model config to completion app config + :param app_model: app model + :param app_model_config: app model config + :param override_config_dict: app model config dict + :return: + """ + if override_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + else: + config_dict = override_config_dict + + app_config = CompletionAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git 
a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py new file mode 100644 index 00000000000000..fb6246972075cf --- /dev/null +++ b/api/core/app/apps/completion/app_generator.py @@ -0,0 +1,293 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.app.apps.completion.app_runner import CompletionAppRunner +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser, Message +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + +logger = logging.getLogger(__name__) + + +class CompletionAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. 
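+
+        A minimal illustrative call, assuming an already-loaded app_model and
+        end_user (these names are not part of this diff):
+
+            response = CompletionAppGenerator().generate(
+                app_model=app_model,
+                user=end_user,
+                args={'query': 'Write a product tagline', 'inputs': {}},
+                invoke_from=InvokeFrom.DEBUGGER,
+                stream=True
+            )
+
+        With stream=True this yields "data: ..." server-sent event lines; with
+        stream=False it returns a single dict.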
+ + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = {} + + # get conversation + conversation = None + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = CompletionAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + inputs=self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get message + message = self._get_message(message_id) + + # completion app + runner = CompletionAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message=message + ) + except GenerateTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def generate_more_like_this(self, app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate a "more like this" App response for an existing message. + + :param app_model: App + :param message_id: message ID + :param user: account or end user + :param invoke_from: invoke from source + :param stream: is stream + """ + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not message: + raise MessageNotExistsError() + + current_app_model_config = app_model.app_model_config + more_like_this = current_app_model_config.more_like_this_dict + + if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: + raise MoreLikeThisDisabledError() + + app_model_config = message.app_model_config + override_model_config_dict = app_model_config.to_dict() + model_dict = override_model_config_dict['model'] + completion_params = model_dict.get('completion_params') + completion_params['temperature'] = 0.9 + model_dict['completion_params'] = completion_params + override_model_config_dict['model'] = model_dict + + # parse files + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + message.files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={} + ) + + # init generate records + ( + conversation, +
message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py new file mode 100644 index 00000000000000..649d73d96180fa --- /dev/null +++ b/api/core/app/apps/completion/app_runner.py @@ -0,0 +1,179 @@ +import logging +from typing import cast + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_runner import AppRunner +from core.app.apps.completion.app_config_manager import CompletionAppConfig +from core.app.entities.app_invoke_entities import ( + CompletionAppGenerateEntity, +) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelInstance +from core.moderation.base import ModerationException +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from extensions.ext_database import db +from models.model import App, Message + +logger = logging.getLogger(__name__) + + +class CompletionAppRunner(AppRunner): + """ + Completion Application Runner + """ + + def run(self, application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(CompletionAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. 
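+        # Illustrative numbers only (not from this diff): with a 16,384-token
+        # context window and max_tokens=1,024, a 15,500-token prompt leaves
+        # 16,384 - 1,024 - 15,500 = -140 tokens, so this check raises before
+        # the model is ever invoked.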
+ # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + # get context from datasets + context = None + if app_config.dataset and app_config.dataset.dataset_ids: + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + application_generate_entity.user_id, + application_generate_entity.invoke_from + ) + + dataset_config = app_config.dataset + if dataset_config and dataset_config.retrieve_config.query_variable: + query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval() + context = dataset_retrieval.retrieve( + tenant_id=app_record.tenant_id, + model_config=application_generate_entity.model_config, + config=dataset_config, + query=query, + invoke_from=application_generate_entity.invoke_from, + show_retrieve_source=app_config.additional_features.show_retrieve_source, + hit_callback=hit_callback + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recalc_llm_max_tokens( + model_config=application_generate_entity.model_config, + prompt_messages=prompt_messages + ) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model + ) + + 
db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=application_generate_entity.model_config.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke result + self._handle_invoke_result( + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream + ) + \ No newline at end of file diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py similarity index 78% rename from api/core/app_runner/generate_task_pipeline.py rename to api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 1cc56483ad3770..412029b02491f8 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -6,16 +6,20 @@ from pydantic import BaseModel -from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler -from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom -from core.entities.queue_entities import ( - AnnotationReplyEvent, +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + AgentChatAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + InvokeFrom, +) +from core.app.entities.queue_entities import ( QueueAgentMessageEvent, QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, QueueErrorEvent, + QueueLLMChunkEvent, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, @@ -26,20 +30,17 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.prompt_template import PromptTemplateParser +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought, MessageFile +from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) @@ -53,13 +54,17 @@ class TaskState(BaseModel): metadata: dict = {} -class GenerateTaskPipeline: +class EasyUIBasedGenerateTaskPipeline: """ - GenerateTaskPipeline is a class that generate stream output and state management for Application. + EasyUIBasedGenerateTaskPipeline is a class that generates stream output and manages state for an Application.
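+
+    Blocking runs return one aggregated dict; streaming runs yield server-sent
+    events built by _yield_response, one per event, e.g. (illustrative payload):
+
+        data: {"event": "message", "task_id": "...", "message_id": "...", "answer": "Hi"}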
""" - def __init__(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + def __init__(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -70,12 +75,13 @@ def __init__(self, application_generate_entity: ApplicationGenerateEntity, :param message: message """ self._application_generate_entity = application_generate_entity + self._model_config = application_generate_entity.model_config self._queue_manager = queue_manager self._conversation = conversation self._message = message self._task_state = TaskState( llm_result=LLMResult( - model=self._application_generate_entity.app_orchestration_config_entity.model_config.model, + model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() @@ -110,7 +116,7 @@ def _process_blocking_response(self) -> dict: raise self._handle_error(event) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -127,7 +133,7 @@ def _process_blocking_response(self) -> dict: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.app_orchestration_config_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -184,7 +190,7 @@ def _process_blocking_response(self) -> dict: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -210,7 +216,7 @@ def _process_stream_response(self) -> Generator: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.app_orchestration_config_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -263,7 +269,7 @@ def _process_stream_response(self) -> Generator: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: replace_response['conversation_id'] = self._conversation.id yield self._yield_response(replace_response) @@ -278,7 +284,7 @@ def _process_stream_response(self) -> Generator: 'message_id': self._message.id, } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -287,7 +293,7 @@ def _process_stream_response(self) -> Generator: yield self._yield_response(response) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, 
AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -325,7 +331,7 @@ def _process_stream_response(self) -> Generator: 'message_files': agent_thought.files } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -356,12 +362,12 @@ def _process_stream_response(self) -> Generator: 'url': url } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) - elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): + elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: @@ -374,14 +380,19 @@ def _process_stream_response(self) -> Generator: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_chunk_message(LLMResultChunk( - model=self._task_state.llm_result.model, - prompt_messages=self._task_state.llm_result.prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) - ), PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueLLMChunkEvent( + chunk=LLMResultChunk( + model=self._task_state.llm_result.model, + prompt_messages=self._task_state.llm_result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) + ) + ) + ), PublishFrom.TASK_PIPELINE + ) + self._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE @@ -402,7 +413,7 @@ def _process_stream_response(self) -> Generator: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -422,7 +433,10 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message = db.session.query(Message).filter(Message.id == self._message.id).first() self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() - self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) + self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + self._model_config.mode, + self._task_state.llm_result.prompt_messages + ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -433,6 +447,7 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price + self._message.currency = usage.currency db.session.commit() @@ -440,7 +455,10 @@ def _save_message(self, llm_result: LLMResult) -> None: self._message, 
application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, + is_first_message=self._application_generate_entity.app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT + ] and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras ) @@ -459,7 +477,7 @@ def _handle_chunk(self, text: str, agent: bool = False) -> dict: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id return response @@ -562,92 +580,21 @@ def _yield_response(self, response: dict) -> str: """ return "data: " + json.dumps(response) + "\n\n" - def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: - """ - Prompt messages to prompt for saving. - :param prompt_messages: prompt messages - :return: - """ - prompts = [] - if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat': - for prompt_message in prompt_messages: - if prompt_message.role == PromptMessageRole.USER: - role = 'user' - elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' - elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' - else: - continue - - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - prompts.append({ - "role": role, - "text": text, - "files": files - }) - else: - prompt_message = prompt_messages[0] - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - params = { - "role": 'user', - "text": text, - } - - if files: - params['files'] = files - - prompts.append(params) - - return prompts - - def _init_output_moderation(self) -> Optional[OutputModerationHandler]: + def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. 
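        Returns an OutputModeration bound to this pipeline's queue manager when
        the app config declares a sensitive_word_avoidance rule; otherwise
        returns None and streamed output is left unmoderated.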
:return: """ - app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity - sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance if sensitive_word_avoidance: - return OutputModerationHandler( - tenant_id=self._application_generate_entity.tenant_id, - app_id=self._application_generate_entity.app_id, + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, rule=ModerationRule( type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), - on_message_replace_func=self._queue_manager.publish_message_replace + queue_manager=self._queue_manager ) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py new file mode 100644 index 00000000000000..5e676c40bd56f5 --- /dev/null +++ b/api/core/app/apps/message_based_app_generator.py @@ -0,0 +1,269 @@ +import json +import logging +from collections.abc import Generator +from typing import Optional, Union + +from sqlalchemy import and_ + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, + AppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + InvokeFrom, +) +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError + +logger = logging.getLogger(__name__) + + +class MessageBasedAppGenerator(BaseAppGenerator): + + def _handle_response(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
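+        Wraps EasyUIBasedGenerateTaskPipeline.process() and returns either a
+        blocking dict (stream=False) or an SSE generator (stream=True); the
+        "I/O operation on closed file." ValueError raised by an interrupted
+        stream is converted into GenerateTaskStoppedException.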
+ :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedException() + else: + logger.exception(e) + raise e + + def _get_conversation_by_user(self, app_model: App, conversation_id: str, + user: Union[Account, EndUser]) -> Conversation: + conversation_filter = [ + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.status == 'normal' + ] + + if isinstance(user, Account): + conversation_filter.append(Conversation.from_account_id == user.id) + else: + conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) + + conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() + + if not conversation: + raise ConversationNotExistsError() + + if conversation.status != 'normal': + raise ConversationCompletedError() + + return conversation + + def _get_app_model_config(self, app_model: App, + conversation: Optional[Conversation] = None) \ + -> AppModelConfig: + if conversation: + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() + + if not app_model_config: + raise AppModelConfigBrokenError() + else: + if app_model.app_model_config_id is None: + raise AppModelConfigBrokenError() + + app_model_config = app_model.app_model_config + + if not app_model_config: + raise AppModelConfigBrokenError() + + return app_model_config + + def _init_generate_records(self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], + conversation: Optional[Conversation] = None) \ + -> tuple[Conversation, Message]: + """ + Initialize generate records + :param application_generate_entity: application generate entity + :return: + """ + app_config = application_generate_entity.app_config + + # get from source + end_user_id = None + account_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + from_source = 'api' + end_user_id = application_generate_entity.user_id + else: + from_source = 'console' + account_id = application_generate_entity.user_id + + if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): + app_model_config_id = None + override_model_configs = None + model_provider = None + model_id = None + else: + app_model_config_id = app_config.app_model_config_id + model_provider = application_generate_entity.model_config.provider + model_id = application_generate_entity.model_config.model + override_model_configs = None + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ + and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + override_model_configs = app_config.app_model_config_dict + + # get conversation introduction + introduction = self._get_conversation_introduction(application_generate_entity) + + if not 
conversation: + conversation = Conversation( + app_id=app_config.app_id, + app_model_config_id=app_model_config_id, + model_provider=model_provider, + model_id=model_id, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_config.app_mode.value, + name='New conversation', + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status='normal', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.commit() + db.session.refresh(conversation) + + message = Message( + app_id=app_config.app_id, + model_provider=model_provider, + model_id=model_id, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=conversation.id, + inputs=application_generate_entity.inputs, + query=application_generate_entity.query or "", + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency='USD', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id + ) + + db.session.add(message) + db.session.commit() + db.session.refresh(message) + + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + belongs_to='user', + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if account_id else 'end_user'), + created_by=account_id or end_user_id, + ) + db.session.add(message_file) + db.session.commit() + + return conversation, message + + def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: + """ + Get conversation introduction + :param application_generate_entity: application generate entity + :return: conversation introduction + """ + app_config = application_generate_entity.app_config + introduction = app_config.additional_features.opening_statement + + if introduction: + try: + inputs = application_generate_entity.inputs + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + introduction = prompt_template.format(prompt_inputs) + except KeyError: + pass + + return introduction + + def _get_conversation(self, conversation_id: str) -> Conversation: + """ + Get conversation by conversation id + :param conversation_id: conversation id + :return: conversation + """ + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + return conversation + + def _get_message(self, message_id: str) -> Message: + """ + Get message by message id + :param message_id: message id + :return: message + """ + message = ( + db.session.query(Message) + .filter(Message.id == message_id) + .first() + ) + + return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py new file mode 100644 index 00000000000000..f4ff44dddac9ef --- /dev/null +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -0,0 +1,61 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom 
+from core.app.entities.queue_entities import ( + AppQueueEvent, + MessageQueueMessage, + QueueAdvancedChatMessageEndEvent, + QueueErrorEvent, + QueueMessage, + QueueMessageEndEvent, + QueueStopEvent, +) + + +class MessageBasedAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + conversation_id: str, + app_mode: str, + message_id: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._conversation_id = str(conversation_id) + self._app_mode = app_mode + self._message_id = str(message_id) + + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + return MessageQueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = MessageQueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueAdvancedChatMessageEndEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() + diff --git a/api/core/app/apps/workflow/__init__.py b/api/core/app/apps/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py new file mode 100644 index 00000000000000..91bab1b21896c9 --- /dev/null +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -0,0 +1,71 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.model import App, AppMode +from models.workflow import Workflow + + +class WorkflowAppConfig(WorkflowUIBasedAppConfig): + """ + Workflow App Config Entity. 
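+    Inherits all fields from WorkflowUIBasedAppConfig; workflow apps currently
+    add no configuration of their own beyond it.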
+ """ + pass + + +class WorkflowAppConfigManager(BaseAppConfigManager): + @classmethod + def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: + features_dict = workflow.features_dict + + app_config = WorkflowAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=features_dict + ), + variables=WorkflowVariablesConfigManager.convert( + workflow=workflow + ), + additional_features=cls.convert_features(features_dict) + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for workflow app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: only validate the structure of the config + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py new file mode 100644 index 00000000000000..b1a70a83ba42e3 --- /dev/null +++ b/api/core/app/apps/workflow/app_generator.py @@ -0,0 +1,170 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowAppGenerator(BaseAppGenerator): + def generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + inputs = args['inputs'] + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs=self._get_cleaned_inputs(inputs, app_config), + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from + ) + + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :return: + """ + with flask_app.app_context(): + try: + # workflow app + runner = WorkflowAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager + ) + except GenerateTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
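+        Delegates to WorkflowAppGenerateTaskPipeline.process() and, as in the
+        message-based generator, converts the closed-stream ValueError into
+        GenerateTaskStoppedException.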
+ :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param user: account or end user + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=stream + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedException() + else: + logger.exception(e) + raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py new file mode 100644 index 00000000000000..f448138b53c0c2 --- /dev/null +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -0,0 +1,46 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, + WorkflowQueueMessage, +) + + +class WorkflowAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage( + task_id=self._task_id, + app_mode=self._app_mode, + event=event + ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py new file mode 100644 index 00000000000000..922c3003bfba41 --- /dev/null +++ b/api/core/app/apps/workflow/app_runner.py @@ -0,0 +1,170 @@ +import logging +import time +from typing import Optional, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent +from core.moderation.base import ModerationException +from core.moderation.input_moderation import InputModeration +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowAppRunner: + """ + Workflow Application Runner + """ + + def run(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Run application + :param 
application_generate_entity: application generate entity + :param queue_manager: application queue manager + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(WorkflowAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + inputs = application_generate_entity.inputs + files = application_generate_entity.files + + # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs + ): + return + + db.session.close() + + # RUN WORKFLOW + workflow_engine_manager = WorkflowEngineManager() + workflow_engine_manager.run_workflow( + workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, + user_inputs=inputs, + system_inputs={ + SystemVariable.FILES: files + }, + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: WorkflowAppGenerateEntity, + inputs: dict) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :return: + """ + try: + # process sensitive_word_avoidance + moderation_feature = InputModeration() + _, inputs, query = moderation_feature.check( + app_id=app_record.id, + tenant_id=app_generate_entity.app_config.tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) + except ModerationException as e: + if app_generate_entity.stream: + self._stream_output( + queue_manager=queue_manager, + text=str(e), + ) + + queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION), + PublishFrom.APPLICATION_MANAGER + ) + return True + + return False + + def _stream_output(self, queue_manager: AppQueueManager, + text: str) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :return: + """ + index = 0 + for token in text: + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) + index += 1 + time.sleep(0.01) + + def moderation_for_inputs(self, app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: dict) -> tuple[bool, dict, str]: + """ + Process sensitive_word_avoidance. 
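+        Thin wrapper around InputModeration.check() with an empty query;
+        returns its (bool, inputs, query) result tuple unchanged.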
+ :param app_id: app id + :param tenant_id: tenant id + :param app_generate_entity: app generate entity + :param inputs: inputs + :return: + """ + moderation_feature = InputModeration() + return moderation_feature.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py new file mode 100644 index 00000000000000..cd1ea4c81eaf79 --- /dev/null +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -0,0 +1,588 @@ +import json +import logging +import time +from collections.abc import Generator +from typing import Optional, Union + +from pydantic import BaseModel, Extra + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageReplaceEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser +from models.workflow import ( + Workflow, + WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) + +logger = logging.getLogger(__name__) + + +class TaskState(BaseModel): + """ + TaskState entity + """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution_id: str + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + answer: str = "" + metadata: dict = {} + + workflow_run_id: Optional[str] = None + start_at: Optional[float] = None + total_tokens: int = 0 + total_steps: int = 0 + + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + +class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): + """ + WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. 
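+        Captures the run context (generate entity, workflow, queue manager,
+        user), starts the elapsed-time clock and initializes optional output
+        moderation before any queue event is consumed.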
+ :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param user: user + :param stream: is stream + """ + self._application_generate_entity = application_generate_entity + self._workflow = workflow + self._queue_manager = queue_manager + self._user = user + self._task_state = TaskState() + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + self._stream = stream + + def process(self) -> Union[dict, Generator]: + """ + Process generate task pipeline. + :return: + """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + + if self._stream: + return self._process_stream_response() + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> dict: + """ + Process blocking response. + :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + raise self._handle_error(event) + elif isinstance(event, QueueWorkflowStartedEvent): + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + workflow_run = self._on_workflow_finished(event) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + response = { + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_run.id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + return response + else: + continue + + def _process_stream_response(self) -> Generator: + """ + Process stream response. 
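+        Emits one SSE payload per queue event: workflow_started, node_started,
+        node_finished, text_chunk / text_replace and finally workflow_finished,
+        plus "event: ping" keep-alives; a QueueErrorEvent is serialized as an
+        error payload and ends the stream.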
+ :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + data = self._error_to_stream_response_data(self._handle_error(event)) + yield self._yield_response(data) + break + elif isinstance(event, QueueWorkflowStartedEvent): + workflow_run = self._on_workflow_start() + + response = { + 'event': 'workflow_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_run.id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'sequence_number': workflow_run.sequence_number, + 'created_at': int(workflow_run.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeStartedEvent): + workflow_node_execution = self._on_node_start(event) + + response = { + 'event': 'node_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + workflow_node_execution = self._on_node_finished(event) + + response = { + 'event': 'node_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'process_data': workflow_node_execution.process_data_dict, + 'outputs': workflow_node_execution.outputs_dict, + 'status': workflow_node_execution.status, + 'error': workflow_node_execution.error, + 'elapsed_time': workflow_node_execution.elapsed_time, + 'execution_metadata': workflow_node_execution.execution_metadata_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()), + 'finished_at': int(workflow_node_execution.finished_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + workflow_run = self._on_workflow_finished(event) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + self._output_moderation_handler = None + + replace_response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': self._task_state.answer + } + } + + yield self._yield_response(replace_response) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_run.id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, 
+ 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None + } + } + + yield self._yield_response(workflow_run_response) + elif isinstance(event, QueueTextChunkEvent): + delta_text = event.text + if delta_text is None: + continue + + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + continue + else: + self._output_moderation_handler.append_new_token(delta_text) + + self._task_state.answer += delta_text + response = self._handle_chunk(delta_text) + yield self._yield_response(response) + elif isinstance(event, QueueMessageReplaceEvent): + response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': event.text + } + } + + yield self._yield_response(response) + elif isinstance(event, QueuePingEvent): + yield "event: ping\n\n" + else: + continue + + def _on_workflow_start(self) -> WorkflowRun: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + user_inputs=self._application_generate_entity.inputs, + system_inputs={ + SystemVariable.FILES: self._application_generate_entity.files + } + ) + + self._task_state.workflow_run_id = workflow_run.id + + db.session.close() + + return workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run=workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution_id=workflow_node_execution.id, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + + self._task_state.total_steps += 1 + + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + 
workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + start_at=current_node_execution.start_at, + error=event.error + ) + + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + + workflow_run = self._workflow_run_success( + workflow_run=workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=outputs + ) + + self._task_state.workflow_run_id = workflow_run.id + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + + db.session.close() + + return workflow_run + + def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + """ + Save workflow app log. + :return: + """ + invoke_from = self._application_generate_entity.invoke_from + if invoke_from == InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + elif invoke_from == InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + elif invoke_from == InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + else: + # not save log for debugging + return + + workflow_app_log = WorkflowAppLog( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + created_from=created_from.value, + created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), + created_by=self._user.id, + ) + db.session.add(workflow_app_log) + db.session.commit() + db.session.close() + + def _handle_chunk(self, text: str) -> dict: + """ + Handle completed event. 
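+        Builds the per-token 'text_chunk' stream payload, i.e.:
+
+            {'event': 'text_chunk', 'workflow_run_id': ..., 'task_id': ...,
+             'data': {'text': text}}
+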
+ :param text: text + :return: + """ + response = { + 'event': 'text_chunk', + 'workflow_run_id': self._task_state.workflow_run_id, + 'task_id': self._application_generate_entity.task_id, + 'data': { + 'text': text + } + } + + return response + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response_data(self, e: Exception) -> dict: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return { + 'event': 'error', + 'task_id': self._application_generate_entity.task_id, + **data + } + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. 
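+        Mirrors the easy-UI pipeline: returns an OutputModeration wired to the
+        queue manager when sensitive_word_avoidance is configured, else None.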
+ :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + queue_manager=self._queue_manager + ) diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py new file mode 100644 index 00000000000000..e5a8e8d3747c42 --- /dev/null +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -0,0 +1,119 @@ +from typing import Optional + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import ( + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType +from models.workflow import Workflow + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): + self._queue_manager = queue_manager + + def on_workflow_run_started(self) -> None: + """ + Workflow run started + """ + self._queue_manager.publish( + QueueWorkflowStartedEvent(), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_succeeded(self) -> None: + """ + Workflow run succeeded + """ + self._queue_manager.publish( + QueueWorkflowSucceededEvent(), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + self._queue_manager.publish( + QueueWorkflowFailedEvent( + error=error + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish( + QueueNodeStartedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + node_run_index=node_run_index, + predecessor_node_id=predecessor_node_id + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: + """ + Workflow node execute succeeded + """ + self._queue_manager.publish( + QueueNodeSucceededEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: + """ + Workflow node execute failed + """ + self._queue_manager.publish( + QueueNodeFailedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + error=error + ), + 
PublishFrom.APPLICATION_MANAGER + ) + + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + """ + Publish text chunk + """ + pass diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py new file mode 100644 index 00000000000000..2b373d28e83957 --- /dev/null +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -0,0 +1,214 @@ +import json +import time +from datetime import datetime +from typing import Optional, Union + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeType +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) + + +class WorkflowBasedGenerateTaskPipeline: + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _workflow_run_success(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Optional[dict] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :return: + """ + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value + workflow_run.outputs = outputs + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _workflow_run_failed(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + status: WorkflowRunStatus, + error: str) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total 
steps + :param status: status + :param error: error message + :return: + """ + workflow_run.status = status.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_run) + db.session.close() + + return workflow_run + + def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run: workflow run + :param node_id: node id + :param node_type: node type + :param node_title: node title + :param node_run_index: run index + :param predecessor_node_id: predecessor node id if exists + :return: + """ + # init workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by + ) + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + """ + Workflow node execution success + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param inputs: inputs + :param process_data: process data + :param outputs: outputs + :param execution_metadata: execution metadata + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution + + def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + error: str) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param error: error message + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.finished_at = 
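Both run and node finalizers compute elapsed_time as time.perf_counter() - start_at, so start_at must be a perf_counter() reading captured at start, not a wall-clock timestamp. A tiny self-contained sketch of that bookkeeping; finalize() is a stand-in for the WorkflowNodeExecution field updates:

import time
from typing import Optional

def finalize(start_at: float, status: str, error: Optional[str] = None) -> dict:
    # mirrors workflow_node_execution.elapsed_time = time.perf_counter() - start_at
    return {
        'status': status,
        'error': error,
        'elapsed_time': time.perf_counter() - start_at,
    }

start_at = time.perf_counter()  # captured when the node starts running
time.sleep(0.01)                # simulated node work
print(finalize(start_at, 'succeeded'))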
datetime.utcnow() + + db.session.commit() + db.session.refresh(workflow_node_execution) + db.session.close() + + return workflow_node_execution diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py new file mode 100644 index 00000000000000..01cbd7d2b2df47 --- /dev/null +++ b/api/core/app/entities/app_invoke_entities.py @@ -0,0 +1,135 @@ +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + +from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.entities.provider_configuration import ProviderModelBundle +from core.file.file_obj import FileObj +from core.model_runtime.entities.model_entities import AIModelEntity + + +class InvokeFrom(Enum): + """ + Invoke From. + """ + SERVICE_API = 'service-api' + WEB_APP = 'web-app' + EXPLORE = 'explore' + DEBUGGER = 'debugger' + + @classmethod + def value_of(cls, value: str) -> 'InvokeFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid invoke from value {value}') + + def to_source(self) -> str: + """ + Get source of invoke from. + + :return: source + """ + if self == InvokeFrom.WEB_APP: + return 'web_app' + elif self == InvokeFrom.DEBUGGER: + return 'dev' + elif self == InvokeFrom.EXPLORE: + return 'explore_app' + elif self == InvokeFrom.SERVICE_API: + return 'api' + + return 'dev' + + +class ModelConfigWithCredentialsEntity(BaseModel): + """ + Model Config With Credentials Entity. + """ + provider: str + model: str + model_schema: AIModelEntity + mode: str + provider_model_bundle: ProviderModelBundle + credentials: dict[str, Any] = {} + parameters: dict[str, Any] = {} + stop: list[str] = [] + + +class AppGenerateEntity(BaseModel): + """ + App Generate Entity. + """ + task_id: str + + # app config + app_config: AppConfig + + inputs: dict[str, str] + files: list[FileObj] = [] + user_id: str + + # extras + stream: bool + invoke_from: InvokeFrom + + # extra parameters, like: auto_generate_conversation_name + extras: dict[str, Any] = {} + + +class EasyUIBasedAppGenerateEntity(AppGenerateEntity): + """ + Chat Application Generate Entity. + """ + # app config + app_config: EasyUIBasedAppConfig + model_config: ModelConfigWithCredentialsEntity + + query: Optional[str] = None + + +class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): + """ + Chat Application Generate Entity. + """ + conversation_id: Optional[str] = None + + +class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): + """ + Completion Application Generate Entity. + """ + pass + + +class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): + """ + Agent Chat Application Generate Entity. + """ + conversation_id: Optional[str] = None + + +class AdvancedChatAppGenerateEntity(AppGenerateEntity): + """ + Advanced Chat Application Generate Entity. + """ + # app config + app_config: WorkflowUIBasedAppConfig + + conversation_id: Optional[str] = None + query: Optional[str] = None + + +class WorkflowAppGenerateEntity(AppGenerateEntity): + """ + Workflow Application Generate Entity. 
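A usage sketch for the InvokeFrom enum defined above, assuming the core.app.entities.app_invoke_entities module path introduced by this diff:

from core.app.entities.app_invoke_entities import InvokeFrom

invoke_from = InvokeFrom.value_of('service-api')   # -> InvokeFrom.SERVICE_API
print(invoke_from.to_source())                     # -> 'api'

try:
    InvokeFrom.value_of('cli')
except ValueError as e:
    print(e)                                       # invalid invoke from value cli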
+ """ + # app config + app_config: WorkflowUIBasedAppConfig diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py new file mode 100644 index 00000000000000..5c31996fd345a6 --- /dev/null +++ b/api/core/app/entities/queue_entities.py @@ -0,0 +1,245 @@ +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType + + +class QueueEvent(Enum): + """ + QueueEvent enum + """ + LLM_CHUNK = "llm_chunk" + TEXT_CHUNK = "text_chunk" + AGENT_MESSAGE = "agent_message" + MESSAGE_REPLACE = "message_replace" + MESSAGE_END = "message_end" + ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_SUCCEEDED = "workflow_succeeded" + WORKFLOW_FAILED = "workflow_failed" + NODE_STARTED = "node_started" + NODE_SUCCEEDED = "node_succeeded" + NODE_FAILED = "node_failed" + RETRIEVER_RESOURCES = "retriever_resources" + ANNOTATION_REPLY = "annotation_reply" + AGENT_THOUGHT = "agent_thought" + MESSAGE_FILE = "message_file" + ERROR = "error" + PING = "ping" + STOP = "stop" + + +class AppQueueEvent(BaseModel): + """ + QueueEvent entity + """ + event: QueueEvent + + +class QueueLLMChunkEvent(AppQueueEvent): + """ + QueueLLMChunkEvent entity + """ + event = QueueEvent.LLM_CHUNK + chunk: LLMResultChunk + + +class QueueTextChunkEvent(AppQueueEvent): + """ + QueueTextChunkEvent entity + """ + event = QueueEvent.TEXT_CHUNK + text: str + metadata: Optional[dict] = None + + +class QueueAgentMessageEvent(AppQueueEvent): + """ + QueueMessageEvent entity + """ + event = QueueEvent.AGENT_MESSAGE + chunk: LLMResultChunk + + +class QueueMessageReplaceEvent(AppQueueEvent): + """ + QueueMessageReplaceEvent entity + """ + event = QueueEvent.MESSAGE_REPLACE + text: str + + +class QueueRetrieverResourcesEvent(AppQueueEvent): + """ + QueueRetrieverResourcesEvent entity + """ + event = QueueEvent.RETRIEVER_RESOURCES + retriever_resources: list[dict] + + +class QueueAnnotationReplyEvent(AppQueueEvent): + """ + QueueAnnotationReplyEvent entity + """ + event = QueueEvent.ANNOTATION_REPLY + message_annotation_id: str + + +class QueueMessageEndEvent(AppQueueEvent): + """ + QueueMessageEndEvent entity + """ + event = QueueEvent.MESSAGE_END + llm_result: Optional[LLMResult] = None + + +class QueueAdvancedChatMessageEndEvent(AppQueueEvent): + """ + QueueAdvancedChatMessageEndEvent entity + """ + event = QueueEvent.ADVANCED_CHAT_MESSAGE_END + + +class QueueWorkflowStartedEvent(AppQueueEvent): + """ + QueueWorkflowStartedEvent entity + """ + event = QueueEvent.WORKFLOW_STARTED + + +class QueueWorkflowSucceededEvent(AppQueueEvent): + """ + QueueWorkflowSucceededEvent entity + """ + event = QueueEvent.WORKFLOW_SUCCEEDED + + +class QueueWorkflowFailedEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + event = QueueEvent.WORKFLOW_FAILED + error: str + + +class QueueNodeStartedEvent(AppQueueEvent): + """ + QueueNodeStartedEvent entity + """ + event = QueueEvent.NODE_STARTED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + node_run_index: int = 1 + predecessor_node_id: Optional[str] = None + + +class QueueNodeSucceededEvent(AppQueueEvent): + """ + QueueNodeSucceededEvent entity + """ + event = QueueEvent.NODE_SUCCEEDED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + inputs: 
Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + execution_metadata: Optional[dict] = None + + error: Optional[str] = None + + +class QueueNodeFailedEvent(AppQueueEvent): + """ + QueueNodeFailedEvent entity + """ + event = QueueEvent.NODE_FAILED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + inputs: Optional[dict] = None + process_data: Optional[dict] = None + + error: str + + +class QueueAgentThoughtEvent(AppQueueEvent): + """ + QueueAgentThoughtEvent entity + """ + event = QueueEvent.AGENT_THOUGHT + agent_thought_id: str + + +class QueueMessageFileEvent(AppQueueEvent): + """ + QueueAgentThoughtEvent entity + """ + event = QueueEvent.MESSAGE_FILE + message_file_id: str + + +class QueueErrorEvent(AppQueueEvent): + """ + QueueErrorEvent entity + """ + event = QueueEvent.ERROR + error: Any + + +class QueuePingEvent(AppQueueEvent): + """ + QueuePingEvent entity + """ + event = QueueEvent.PING + + +class QueueStopEvent(AppQueueEvent): + """ + QueueStopEvent entity + """ + class StopBy(Enum): + """ + Stop by enum + """ + USER_MANUAL = "user-manual" + ANNOTATION_REPLY = "annotation-reply" + OUTPUT_MODERATION = "output-moderation" + INPUT_MODERATION = "input-moderation" + + event = QueueEvent.STOP + stopped_by: StopBy + + +class QueueMessage(BaseModel): + """ + QueueMessage entity + """ + task_id: str + app_mode: str + event: AppQueueEvent + + +class MessageQueueMessage(QueueMessage): + """ + MessageQueueMessage entity + """ + message_id: str + conversation_id: str + + +class WorkflowQueueMessage(QueueMessage): + """ + WorkflowQueueMessage entity + """ + pass diff --git a/api/core/app/features/__init__.py b/api/core/app/features/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/app/features/annotation_reply/__init__.py b/api/core/app/features/annotation_reply/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py similarity index 98% rename from api/core/features/annotation_reply.py rename to api/core/app/features/annotation_reply/annotation_reply.py index fd516e465ff38c..19ff94de5e8d58 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/app/features/hosting_moderation/__init__.py b/api/core/app/features/hosting_moderation/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py similarity index 71% rename from api/core/features/hosting_moderation.py rename to api/core/app/features/hosting_moderation/hosting_moderation.py index d8ae7adcac5d47..ec316248a27afe 100644 --- a/api/core/features/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,6 +1,6 @@ import logging -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation from core.model_runtime.entities.message_entities 
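The QueueMessage envelopes above carry the routing metadata: chat apps wrap events in MessageQueueMessage (with message/conversation ids), while workflow apps use the id-free WorkflowQueueMessage. A small construction sketch, assuming the same module path; the ids are illustrative:

from core.app.entities.queue_entities import (
    QueueStopEvent,
    QueueTextChunkEvent,
    WorkflowQueueMessage,
)

message = WorkflowQueueMessage(
    task_id='task-1',          # illustrative id
    app_mode='workflow',
    event=QueueTextChunkEvent(text='partial output'),
)
print(message.json())          # enums serialize by value, e.g. "text_chunk"

stop = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
print(stop.event.value)        # stop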
import PromptMessage @@ -8,7 +8,7 @@ class HostingModerationFeature: - def check(self, application_generate_entity: ApplicationGenerateEntity, + def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -16,8 +16,7 @@ def check(self, application_generate_entity: ApplicationGenerateEntity, :param prompt_messages: prompt messages :return: """ - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - model_config = app_orchestration_config.model_config + model_config = application_generate_entity.model_config text = "" for prompt_message in prompt_messages: diff --git a/api/core/application_manager.py b/api/core/application_manager.py deleted file mode 100644 index 9aca61c7bb40f2..00000000000000 --- a/api/core/application_manager.py +++ /dev/null @@ -1,755 +0,0 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast - -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app_runner.assistant_app_runner import AssistantApplicationRunner -from core.app_runner.basic_app_runner import BasicApplicationRunner -from core.app_runner.generate_task_pipeline import GenerateTaskPipeline -from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, -) -from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileObj -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_template import PromptTemplateParser -from core.provider_manager import ProviderManager -from core.tools.prompt.template import REACT_PROMPT_TEMPLATES -from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) - - -class ApplicationManager: - """ - This class is responsible for managing application - """ - - def generate(self, tenant_id: str, - app_id: str, - app_model_config_id: str, - app_model_config_dict: dict, - app_model_config_override: bool, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. 
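HostingModerationFeature.check now reads model_config straight off the EasyUI entity and, per the loop that follows, concatenates the prompt text before moderating it. A sketch of that shape; the banned-word check is a hypothetical stand-in for core.helper.moderation, whose real API this diff does not show, and UserPromptMessage is assumed to live alongside PromptMessage:

from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage

def gather_text(prompt_messages: list[PromptMessage]) -> str:
    # collect only plain-string contents, as multimodal contents may be lists
    text = ''
    for prompt_message in prompt_messages:
        if isinstance(prompt_message.content, str):
            text += prompt_message.content + '\n'
    return text

BANNED = {'forbidden'}

def check_moderation(text: str) -> bool:
    # hypothetical stand-in for the real moderation helper
    return any(word in text.lower() for word in BANNED)

print(check_moderation(gather_text([UserPromptMessage(content='hello world')])))  # False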
- - :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # init application generate entity - application_generate_entity = ApplicationGenerateEntity( - task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, - app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self._convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_orchestration_config_entity.agent: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = ApplicationQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. 
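The removed generate() above followed a producer/consumer shape: spawn a worker thread that runs the app and publishes into a queue manager, then stream results from the queue back to the caller. A dependency-free sketch of that shape, with queue.Queue standing in for ApplicationQueueManager and None as the stop sentinel:

import queue
import threading
from collections.abc import Generator

def _worker(q: queue.Queue) -> None:
    for token in ['Hello', ',', ' world']:
        q.put(token)          # stands in for queue_manager.publish(...)
    q.put(None)               # end-of-stream sentinel

def generate() -> Generator[str, None, None]:
    q: queue.Queue = queue.Queue()
    threading.Thread(target=_worker, args=(q,), daemon=True).start()
    while True:
        item = q.get()
        if item is None:
            break
        yield item

print(''.join(generate()))    # Hello, world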
- :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_orchestration_config_entity.agent: - # agent app - runner = AssistantApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - else: - # basic app - runner = BasicApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() - - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - - def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ - -> AppOrchestrationConfigEntity: - """ - Convert app model config dict to entity. 
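_handle_response above turns the low-level ValueError raised by a stopped stream into the domain-level stop exception instead of letting it propagate raw. A runnable sketch of that translation, with ConversationTaskStoppedException re-declared locally for self-containment:

class ConversationTaskStoppedException(Exception):
    pass

def handle(process):
    try:
        return process()
    except ValueError as e:
        if e.args[0] == 'I/O operation on closed file.':  # stopped stream, not a real error
            raise ConversationTaskStoppedException()
        raise

def closed_stream():
    raise ValueError('I/O operation on closed file.')

try:
    handle(closed_stream)
except ConversationTaskStoppedException:
    print('task stopped')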
- :param tenant_id: tenant ID - :param app_model_config_dict: app model config dict - :raises ProviderTokenNotInitError: provider token not init error - :return: app orchestration config entity - """ - properties = {} - - copy_app_model_config_dict = app_model_config_dict.copy() - - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=copy_app_model_config_dict['model']['provider'], - model_type=ModelType.LLM - ) - - provider_name = provider_model_bundle.configuration.provider.provider - model_name = copy_app_model_config_dict['model']['name'] - - model_type_instance = provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # check model credentials - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=copy_app_model_config_dict['model']['name'] - ) - - if model_credentials is None: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = copy_app_model_config_dict['model'].get('completion_params') - stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] - - # get model mode - model_mode = copy_app_model_config_dict['model'].get('mode') - if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=copy_app_model_config_dict['model']['name'], - credentials=model_credentials - ) - - model_mode = mode_enum.value - - model_schema = model_type_instance.get_model_schema( - copy_app_model_config_dict['model']['name'], - model_credentials - ) - - if not model_schema: - raise ValueError(f"Model {model_name} not exist.") - - properties['model_config'] = ModelConfigEntity( - provider=copy_app_model_config_dict['model']['provider'], - model=copy_app_model_config_dict['model']['name'], - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - # prompt template - prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type']) - if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: - simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "") - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) - else: - advanced_chat_prompt_template = None - chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {}) - if chat_prompt_config: - chat_prompt_messages = [] - for message in chat_prompt_config.get("prompt", []): - 
chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) - - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) - - advanced_completion_prompt_template = None - completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {}) - if completion_prompt_config: - completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], - } - - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] - } - - advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( - **completion_prompt_template_params - ) - - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template - ) - - # external data variables - properties['external_data_variables'] = [] - - # old external_data_tools - external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) - for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: - continue - - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] - ) - ) - - # current external_data_tools - for variable in copy_app_model_config_dict.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] - ) - ) - - # show retrieve source - show_retrieve_source = False - retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource') - if retriever_resource_dict: - if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: - show_retrieve_source = True - - properties['show_retrieve_source'] = show_retrieve_source - - dataset_ids = [] - if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): - datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) - - - for dataset in datasets.get('datasets', []): - keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': - continue - dataset = dataset['dataset'] - - if 'enabled' not in dataset or not dataset['enabled']: - continue - - dataset_id = dataset.get('id', None) - if dataset_id: - dataset_ids.append(dataset_id) - else: - datasets = {'strategy': 'router', 'datasets': []} - - if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ - and 'enabled' in copy_app_model_config_dict['agent_mode'] \ - and copy_app_model_config_dict['agent_mode']['enabled']: - - agent_dict = copy_app_model_config_dict.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to 
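The agent block here resolves the strategy string: 'function_call' and 'cot'/'react' map directly, while older configs fall back to a provider-based guess, as the continuation below shows. A compact sketch with the Strategy enum re-declared locally for a runnable example:

from enum import Enum

class Strategy(Enum):
    FUNCTION_CALLING = 'function_call'
    CHAIN_OF_THOUGHT = 'chain_of_thought'

def resolve_strategy(agent_strategy: str, provider: str) -> Strategy:
    if agent_strategy == 'function_call':
        return Strategy.FUNCTION_CALLING
    if agent_strategy in ('cot', 'react'):
        return Strategy.CHAIN_OF_THOUGHT
    # old configs: OpenAI models historically defaulted to function calling
    return Strategy.FUNCTION_CALLING if provider == 'openai' else Strategy.CHAIN_OF_THOUGHT

print(resolve_strategy('react', 'anthropic').value)   # chain_of_thought
print(resolve_strategy('legacy', 'openai').value)     # function_call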
detect default strategy - if copy_app_model_config_dict['model']['provider'] == 'openai': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - - agent_tools = [] - for tool in agent_dict.get('tools', []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: - continue - - agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} - } - - agent_tools.append(AgentToolEntity(**agent_tool_properties)) - elif len(keys) == 1: - # old standard - key = list(tool.keys())[0] - - if key != 'dataset': - continue - - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - continue - - dataset_id = tool_item['id'] - dataset_ids.append(dataset_id) - - if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ - copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} - # check model mode - model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), - ) - else: - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), - ) - - properties['agent'] = AgentEntity( - provider=properties['model_config'].provider, - model=properties['model_config'].model, - strategy=strategy, - prompt=agent_prompt_entity, - tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) - ) - - if len(dataset_ids) > 0: - # dataset configs - dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) - query_variable = copy_app_model_config_dict.get('dataset_query_variable') - - if dataset_configs['retrieval_model'] == 'single': - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - single_strategy=datasets.get('strategy', 'router') - ) - ) - else: - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - top_k=dataset_configs.get('top_k'), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') - ) - ) - - # file upload - file_upload_dict = copy_app_model_config_dict.get('file_upload') - if file_upload_dict: - if 'image' in file_upload_dict and file_upload_dict['image']: - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - properties['file_upload'] = FileUploadEntity( - image_config={ - 'number_limits': file_upload_dict['image']['number_limits'], - 'detail': 
file_upload_dict['image']['detail'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] - } - ) - - # opening statement - properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement') - - # suggested questions after answer - suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer') - if suggested_questions_after_answer_dict: - if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: - properties['suggested_questions_after_answer'] = True - - # more like this - more_like_this_dict = copy_app_model_config_dict.get('more_like_this') - if more_like_this_dict: - if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']: - properties['more_like_this'] = True - - # speech to text - speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text') - if speech_to_text_dict: - if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']: - properties['speech_to_text'] = True - - # text to speech - text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech') - if text_to_speech_dict: - if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: - properties['text_to_speech'] = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), - ) - - # sensitive word avoidance - sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance') - if sensitive_word_avoidance_dict: - if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: - properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), - ) - - return AppOrchestrationConfigEntity(**properties) - - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials - ) - - app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - 
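The feature blocks around this point (file_upload, suggested_questions_after_answer, more_like_this, speech_to_text, text_to_speech, sensitive_word_avoidance) all repeat the same "'enabled' in d and d['enabled']" guard. A tiny helper expressing that idiom, run against illustrative config values:

from typing import Optional

def is_enabled(config: Optional[dict]) -> bool:
    # equivalent to: config and 'enabled' in config and config['enabled']
    return bool(config and config.get('enabled'))

app_model_config = {
    'more_like_this': {'enabled': True},
    'speech_to_text': {'enabled': False},
}
print(is_enabled(app_model_config.get('more_like_this')))   # True
print(is_enabled(app_model_config.get('speech_to_text')))   # False
print(is_enabled(app_model_config.get('text_to_speech')))   # False (key absent)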
app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - db.session.refresh(conversation) - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None - ) - - db.session.add(message) - db.session.commit() - db.session.refresh(message) - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == 
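A usage sketch for the introduction templating in _get_conversation_introduction above, assuming the pre-existing PromptTemplateParser API exactly as used there ({{var}} placeholders, variable_keys, format):

from core.prompt.prompt_template import PromptTemplateParser

inputs = {'target_language': 'French', 'unused': 'ignored'}
prompt_template = PromptTemplateParser(template='I translate into {{target_language}}.')
# keep only the variables the template actually references
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
print(prompt_template.format(prompt_inputs))  # I translate into French.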
message_id) - .first() - ) - - return message diff --git a/api/core/callback_handler/__init__.py b/api/core/callback_handler/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py deleted file mode 100644 index 1d25b8ab69d7e5..00000000000000 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ /dev/null @@ -1,262 +0,0 @@ -import json -import logging -import time -from typing import Any, Optional, Union, cast - -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - -from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.callback_handler.entity.agent_loop import AgentLoop -from core.entities.application_entities import ModelConfigEntity -from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from extensions.ext_database import db -from models.model import Message, MessageAgentThought, MessageChain - - -class AgentLoopGatherCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - raise_error: bool = True - - def __init__(self, model_config: ModelConfigEntity, - queue_manager: ApplicationQueueManager, - message: Message, - message_chain: MessageChain) -> None: - """Initialize callback handler.""" - self.model_config = model_config - self.queue_manager = queue_manager - self.message = message - self.message_chain = message_chain - model_type_instance = self.model_config.provider_model_bundle.model_type_instance - self.model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def agent_loops(self) -> list[AgentLoop]: - return self._agent_loops - - def clear_agent_loops(self) -> None: - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return True - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return True - - def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None: - if not self._current_loop: - # Agent start with a LLM query - self._current_loop = AgentLoop( - position=len(self._agent_loops) + 1, - prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]), - status='llm_started', - started_at=time.perf_counter() - ) - - def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None: - if self._current_loop and self._current_loop.status == 'llm_started': - self._current_loop.status = 'llm_end' - if result.usage: - self._current_loop.prompt_tokens = result.usage.prompt_tokens - else: - self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)] - ) - - completion_message = result.message - if completion_message.tool_calls: - 
self._current_loop.completion \ - = json.dumps({'function_call': completion_message.tool_calls}) - else: - self._current_loop.completion = completion_message.content - - if result.usage: - self._current_loop.completion_tokens = result.usage.completion_tokens - else: - self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)] - ) - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - pass - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.debug("Agent on_llm_error: %s", error) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - # kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '} - # input_str='action-input' - # serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'} - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = json.dumps({"query": action.tool_input} - if isinstance(action.tool_input, str) else action.tool_input) - completion = None - if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \ - or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction): - thought = action.log.strip() - completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']}) - else: - action_name_position = action.log.index("Action:") if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - - if self._current_loop and self._current_loop.status == 'llm_end': - self._current_loop.status = 'agent_action' - self._current_loop.thought = thought - self._current_loop.tool_name = tool - self._current_loop.tool_input = tool_input - if completion is not None: - self._current_loop.completion = completion - - self._message_agent_thought = self._init_agent_thought() - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - # kwargs={'name': 'Search'} - # llm_prefix='Thought:' - # observation_prefix='Observation: ' - # output='53 years' - - if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None': - self._current_loop.status = 'tool_end' - self._current_loop.tool_output = output - self._current_loop.completed = True - self._current_loop.completed_at = time.perf_counter() - self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at - - self._complete_agent_thought(self._message_agent_thought) - - self._agent_loops.append(self._current_loop) - self._current_loop = None - 
self._message_agent_thought = None - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - logging.debug("Agent on_tool_error: %s", error) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - # Final Answer - if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'): - self._current_loop.status = 'agent_finish' - self._current_loop.completed = True - self._current_loop.completed_at = time.perf_counter() - self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at - self._current_loop.thought = '[DONE]' - self._message_agent_thought = self._init_agent_thought() - - self._complete_agent_thought(self._message_agent_thought) - - self._agent_loops.append(self._current_loop) - self._current_loop = None - self._message_agent_thought = None - elif not self._current_loop and self._agent_loops: - self._agent_loops[-1].status = 'agent_finish' - - def _init_agent_thought(self) -> MessageAgentThought: - message_agent_thought = MessageAgentThought( - message_id=self.message.id, - message_chain_id=self.message_chain.id, - position=self._current_loop.position, - thought=self._current_loop.thought, - tool=self._current_loop.tool_name, - tool_input=self._current_loop.tool_input, - message=self._current_loop.prompt, - message_price_unit=0, - answer=self._current_loop.completion, - answer_price_unit=0, - created_by_role=('account' if self.message.from_source == 'console' else 'end_user'), - created_by=(self.message.from_account_id - if self.message.from_source == 'console' else self.message.from_end_user_id) - ) - - db.session.add(message_agent_thought) - db.session.commit() - - self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER) - - return message_agent_thought - - def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: - loop_message_tokens = self._current_loop.prompt_tokens - loop_answer_tokens = self._current_loop.completion_tokens - - # transform usage - llm_usage = self.model_type_instance._calc_response_usage( - self.model_config.model, - self.model_config.credentials, - loop_message_tokens, - loop_answer_tokens - ) - - message_agent_thought.observation = self._current_loop.tool_output - message_agent_thought.tool_process_data = '' # currently not support - message_agent_thought.message_token = loop_message_tokens - message_agent_thought.message_unit_price = llm_usage.prompt_unit_price - message_agent_thought.message_price_unit = llm_usage.prompt_price_unit - message_agent_thought.answer_token = loop_answer_tokens - message_agent_thought.answer_unit_price = llm_usage.completion_unit_price - message_agent_thought.answer_price_unit = llm_usage.completion_price_unit - message_agent_thought.latency = self._current_loop.latency - message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens - message_agent_thought.total_price = llm_usage.total_price - message_agent_thought.currency = llm_usage.currency - db.session.commit() diff --git a/api/core/callback_handler/entity/agent_loop.py b/api/core/callback_handler/entity/agent_loop.py deleted file mode 100644 index 56634bb19e4990..00000000000000 --- a/api/core/callback_handler/entity/agent_loop.py +++ /dev/null @@ -1,23 +0,0 @@ -from pydantic import 
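The removed AgentLoopGatherCallbackHandler above drove each AgentLoop through a small status machine: llm_started -> llm_end -> agent_action -> tool_end, with agent_finish reachable from llm_end or agent_action. A compact sketch of those transitions:

TRANSITIONS = {
    'llm_started': {'llm_end'},
    'llm_end': {'agent_action', 'agent_finish'},
    'agent_action': {'tool_end', 'agent_finish'},
}

def advance(status: str, new_status: str) -> str:
    if new_status not in TRANSITIONS.get(status, set()):
        raise ValueError(f'illegal transition {status} -> {new_status}')
    return new_status

status = 'llm_started'
for step in ('llm_end', 'agent_action', 'tool_end'):
    status = advance(status, step)
print(status)  # tool_end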
BaseModel - - -class AgentLoop(BaseModel): - position: int = 1 - - thought: str = None - tool_name: str = None - tool_input: str = None - tool_output: str = None - - prompt: str = None - prompt_tokens: int = 0 - completion: str = None - completion_tokens: int = 0 - - latency: float = None - - status: str = 'llm_started' - completed: bool = False - - started_at: float = None - completed_at: float = None \ No newline at end of file diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 879c9df69dca54..8e1f496b226c14 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,7 @@ -from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import InvokeFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment @@ -10,7 +11,7 @@ class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: ApplicationQueueManager, + def __init__(self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, @@ -82,4 +83,7 @@ def return_retriever_resource_info(self, resource: list): db.session.add(dataset_retriever_resource) db.session.commit() - self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER) + self._queue_manager.publish( + QueueRetrieverResourcesEvent(retriever_resources=resource), + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py deleted file mode 100644 index 1f95471afb2c73..00000000000000 --- a/api/core/callback_handler/std_out_callback_handler.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import sys -from typing import Any, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - - -class DifyStdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - print_text("\n[on_chat_model_start]\n", color='blue') - for sub_messages in messages: - for sub_message in sub_messages: - print_text(str(sub_message) + "\n", color='blue') - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - print_text("\n[on_llm_start]\n", color='blue') - print_text(prompts[0] + "\n", color='blue') - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str( - response.llm_output) + "\n", color='blue') - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, 
KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') - - def on_chain_start( - self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - chain_type = serialized['id'][-1] - print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - - def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink') - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_start] " + str(serialized), color='yellow') - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = action.tool_input - try: - action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - except ValueError: - thought = '' - - log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" - print_text("\n[on_agent_action]\n" + log + "\n", color='green') - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - print_text("\n[on_tool_end]\n", color='yellow') - if observation_prefix: - print_text(f"\n{observation_prefix}") - print_text(output, color='yellow') - if llm_prefix: - print_text(f"\n{llm_prefix}") - print_text("\n") - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow') - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - -class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): - """Callback handler for streaming. 
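The ignore_* properties above all repeat the same DEBUG-env check, ignoring callbacks unless DEBUG is set to 'true'. The same predicate as a single helper:

import os

def debug_enabled() -> bool:
    # the handler ignores its callbacks exactly when this returns False
    return (os.environ.get('DEBUG') or '').lower() == 'true'

os.environ['DEBUG'] = 'true'
print(not debug_enabled())   # False -> callbacks active
os.environ.pop('DEBUG')
print(not debug_enabled())   # True  -> callbacks ignored outside debug runs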
Only works with LLMs that support streaming.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py deleted file mode 100644 index c1f8fb7e8964a9..00000000000000 --- a/api/core/entities/queue_entities.py +++ /dev/null @@ -1,133 +0,0 @@ -from enum import Enum -from typing import Any - -from pydantic import BaseModel - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk - - -class QueueEvent(Enum): - """ - QueueEvent enum - """ - MESSAGE = "message" - AGENT_MESSAGE = "agent_message" - MESSAGE_REPLACE = "message-replace" - MESSAGE_END = "message-end" - RETRIEVER_RESOURCES = "retriever-resources" - ANNOTATION_REPLY = "annotation-reply" - AGENT_THOUGHT = "agent-thought" - MESSAGE_FILE = "message-file" - ERROR = "error" - PING = "ping" - STOP = "stop" - - -class AppQueueEvent(BaseModel): - """ - QueueEvent entity - """ - event: QueueEvent - - -class QueueMessageEvent(AppQueueEvent): - """ - QueueMessageEvent entity - """ - event = QueueEvent.MESSAGE - chunk: LLMResultChunk - -class QueueAgentMessageEvent(AppQueueEvent): - """ - QueueMessageEvent entity - """ - event = QueueEvent.AGENT_MESSAGE - chunk: LLMResultChunk - - -class QueueMessageReplaceEvent(AppQueueEvent): - """ - QueueMessageReplaceEvent entity - """ - event = QueueEvent.MESSAGE_REPLACE - text: str - - -class QueueRetrieverResourcesEvent(AppQueueEvent): - """ - QueueRetrieverResourcesEvent entity - """ - event = QueueEvent.RETRIEVER_RESOURCES - retriever_resources: list[dict] - - -class AnnotationReplyEvent(AppQueueEvent): - """ - AnnotationReplyEvent entity - """ - event = QueueEvent.ANNOTATION_REPLY - message_annotation_id: str - - -class QueueMessageEndEvent(AppQueueEvent): - """ - QueueMessageEndEvent entity - """ - event = QueueEvent.MESSAGE_END - llm_result: LLMResult - - -class QueueAgentThoughtEvent(AppQueueEvent): - """ - QueueAgentThoughtEvent entity - """ - event = QueueEvent.AGENT_THOUGHT - agent_thought_id: str - -class QueueMessageFileEvent(AppQueueEvent): - """ - QueueAgentThoughtEvent entity - """ - event = QueueEvent.MESSAGE_FILE - message_file_id: str - -class QueueErrorEvent(AppQueueEvent): - """ - QueueErrorEvent entity - """ - event = QueueEvent.ERROR - error: Any - - -class QueuePingEvent(AppQueueEvent): - """ - QueuePingEvent entity - """ - event = QueueEvent.PING - - -class QueueStopEvent(AppQueueEvent): - """ - QueueStopEvent entity - """ - class StopBy(Enum): - """ - Stop by enum - """ - USER_MANUAL = "user-manual" - ANNOTATION_REPLY = "annotation-reply" - OUTPUT_MODERATION = "output-moderation" - - event = QueueEvent.STOP - stopped_by: StopBy - - -class QueueMessage(BaseModel): - """ - QueueMessage entity - """ - task_id: str - message_id: str - conversation_id: str - app_mode: str - event: AppQueueEvent diff --git a/api/core/features/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py similarity index 88% rename from api/core/features/external_data_fetch.py rename to api/core/external_data_tool/external_data_fetch.py index 7f23c8ed728096..8601cb34e79582 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,18 +1,17 @@ import concurrent -import json import logging from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app 
-from core.entities.application_entities import ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory logger = logging.getLogger(__name__) -class ExternalDataFetchFeature: +class ExternalDataFetch: def fetch(self, tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], @@ -28,12 +27,6 @@ def fetch(self, tenant_id: str, :param query: the query :return: the filled inputs """ - # Group tools by type and config - grouped_tools = {} - for tool in external_data_tools: - tool_key = (tool.type, json.dumps(tool.config, sort_keys=True)) - grouped_tools.setdefault(tool_key, []).append(tool) - results = {} with ThreadPoolExecutor() as executor: futures = {} diff --git a/api/core/features/dataset_retrieval/agent/agent_llm_callback.py b/api/core/features/dataset_retrieval/agent/agent_llm_callback.py deleted file mode 100644 index 5ec549de8ee5b1..00000000000000 --- a/api/core/features/dataset_retrieval/agent/agent_llm_callback.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -from typing import Optional - -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class AgentLLMCallback(Callback): - - def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None: - self.agent_callback = agent_callback - - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_before_invoke( - prompt_messages=prompt_messages - ) - - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - pass - - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - After invoke callback - - :param llm_instance: LLM 
instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_after_invoke( - result=result - ) - - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.agent_callback.on_llm_error( - error=ex - ) diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 435074f7430f4c..bd896719c21835 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +from core.app.app_config.entities import FileUploadEntity from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db @@ -50,7 +51,7 @@ class FileObj(BaseModel): transfer_method: FileTransferMethod url: Optional[str] upload_file_id: Optional[str] - file_config: dict + file_upload_entity: FileUploadEntity @property def data(self) -> Optional[str]: @@ -63,7 +64,7 @@ def preview_url(self) -> Optional[str]: @property def prompt_message_content(self) -> ImagePromptMessageContent: if self.type == FileType.IMAGE: - image_config = self.file_config.get('image') + image_config = self.file_upload_entity.image_config return ImagePromptMessageContent( data=self.data, diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 1b7b8b87da7214..9d122c41204308 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Union import requests +from core.app.app_config.entities import FileUploadEntity from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType from extensions.ext_database import db from models.account import Account -from models.model import AppModelConfig, EndUser, MessageFile, UploadFile +from models.model import EndUser, MessageFile, UploadFile from services.file_service import IMAGE_EXTENSIONS @@ -15,18 +16,16 @@ def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, + def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity, user: Union[Account, EndUser]) -> list[FileObj]: """ validate and transform files arg :param files: - :param app_model_config: + :param file_upload_entity: :param user: :return: """ - file_upload_config = app_model_config.file_upload_dict - for file in files: if not isinstance(file, dict): raise ValueError('Invalid file format, must be dict') @@ -45,17 +44,17 @@ def 
validate_and_transform_files_arg(self, files: list[dict], app_model_config: raise ValueError('Missing file upload_file_id') # transform files to file objs - type_file_objs = self._to_file_objs(files, file_upload_config) + type_file_objs = self._to_file_objs(files, file_upload_entity) # validate files new_files = [] for file_type, file_objs in type_file_objs.items(): if file_type == FileType.IMAGE: # parse and validate files - image_config = file_upload_config.get('image') + image_config = file_upload_entity.image_config # check if image file feature is enabled - if not image_config['enabled']: + if not image_config: continue # Validate number of files @@ -96,27 +95,27 @@ def validate_and_transform_files_arg(self, files: list[dict], app_model_config: # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: + def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]: """ transform message files :param files: - :param app_model_config: + :param file_upload_entity: :return: """ # transform files to file objs - type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict) + type_file_objs = self._to_file_objs(files, file_upload_entity) # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_upload_config: dict) -> dict[FileType, list[FileObj]]: + file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]: """ transform files to file objs :param files: - :param file_upload_config: + :param file_upload_entity: :return: """ type_file_objs: dict[FileType, list[FileObj]] = { @@ -133,7 +132,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], if file.belongs_to == FileBelongsTo.ASSISTANT.value: continue - file_obj = self._to_file_obj(file, file_upload_config) + file_obj = self._to_file_obj(file, file_upload_entity) if file_obj.type not in type_file_objs: continue @@ -141,7 +140,7 @@ def _to_file_objs(self, files: list[Union[dict, MessageFile]], return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj: + def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj: """ transform file to file obj @@ -156,7 +155,7 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) transfer_method=transfer_method, url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - file_config=file_upload_config + file_upload_entity=file_upload_entity ) else: return FileObj( @@ -166,7 +165,7 @@ def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, upload_file_id=file.upload_file_id or None, - file_config=file_upload_config + file_upload_entity=file_upload_entity ) def _check_image_remote_url(self, url): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py new file mode 100644 index 00000000000000..9d74edee0e5248 --- /dev/null +++ b/api/core/helper/code_executor/code_executor.py @@ -0,0 +1,84 @@ +from os import environ +from typing import Literal, Optional + +from httpx import post +from 
pydantic import BaseModel
+from yarl import URL
+
+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
+from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
+from core.helper.code_executor.python_transformer import PythonTemplateTransformer
+
+# Code Executor
+CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '')
+CODE_EXECUTION_API_KEY = environ.get('CODE_EXECUTION_API_KEY', '')
+
+class CodeExecutionException(Exception):
+    pass
+
+class CodeExecutionResponse(BaseModel):
+    class Data(BaseModel):
+        stdout: Optional[str]
+        error: Optional[str]
+
+    code: int
+    message: str
+    data: Data
+
+class CodeExecutor:
+    @classmethod
+    def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+        """
+        Execute code
+        :param language: code language
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+        template_transformer = None
+        if language == 'python3':
+            template_transformer = PythonTemplateTransformer
+        elif language == 'jinja2':
+            template_transformer = Jinja2TemplateTransformer
+        elif language == 'javascript':
+            template_transformer = NodeJsTemplateTransformer
+        else:
+            raise CodeExecutionException('Unsupported language')
+
+        runner = template_transformer.transform_caller(code, inputs)
+        url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
+        headers = {
+            'X-Api-Key': CODE_EXECUTION_API_KEY
+        }
+        # jinja2 templates are wrapped in a Python runner, so they map to
+        # 'python3' in the sandbox; javascript maps to 'nodejs'
+        data = {
+            'language': 'python3' if language == 'jinja2' else
+                'nodejs' if language == 'javascript' else
+                'python3' if language == 'python3' else None,
+            'code': runner,
+        }
+
+        try:
+            response = post(str(url), json=data, headers=headers)
+            if response.status_code == 503:
+                raise CodeExecutionException('Code execution service is unavailable')
+            elif response.status_code != 200:
+                raise Exception('Failed to execute code')
+        except CodeExecutionException as e:
+            raise e
+        except Exception as e:
+            raise CodeExecutionException('Failed to execute code')
+
+        try:
+            response = response.json()
+        except Exception:
+            raise CodeExecutionException('Failed to parse response')
+
+        response = CodeExecutionResponse(**response)
+
+        if response.code != 0:
+            raise CodeExecutionException(response.message)
+
+        if response.data.error:
+            raise CodeExecutionException(response.data.error)
+
+        return template_transformer.transform_response(response.data.stdout)
\ No newline at end of file
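Taken together, execute_code picks a template transformer for the requested language, POSTs the generated runner script to the sandbox's /v1/sandbox/run endpoint, and parses the marker-delimited stdout back into a dict. A minimal usage sketch follows; it is illustrative only and not part of this diff: it assumes CODE_EXECUTION_ENDPOINT and CODE_EXECUTION_API_KEY are set as in api/.env.example, and the main function and inputs are made up.

# Hypothetical caller for CodeExecutor; endpoint and API key come from the
# environment (see api/.env.example for the dify-sandbox defaults).
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor

code = """
def main(a: int, b: int) -> dict:
    return {'sum': a + b}
"""

try:
    result = CodeExecutor.execute_code(language='python3', code=code, inputs={'a': 1, 'b': 2})
    print(result)  # expected: {'sum': 3}
except CodeExecutionException as e:
    print(f'sandbox error: {e}')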
diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py
new file mode 100644
index 00000000000000..cc6ad16c66d833
--- /dev/null
+++ b/api/core/helper/code_executor/javascript_transformer.py
@@ -0,0 +1,53 @@
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+NODEJS_RUNNER = """// declare main function here
+{{code}}
+
+// execute main function, and return the result
+// inputs is an object and is passed to main as a single argument
+output = main({{inputs}})
+
+// convert output to json and print
+output = JSON.stringify(output)
+
+result = `<<RESULT>>${output}<<RESULT>>`
+
+console.log(result)
+"""
+
+
+class NodeJsTemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to a Node.js runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = NODEJS_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+        return json.loads(result)
diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py
new file mode 100644
index 00000000000000..87e8ce130f2820
--- /dev/null
+++ b/api/core/helper/code_executor/jina2_transformer.py
@@ -0,0 +1,54 @@
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+PYTHON_RUNNER = """
+import jinja2
+
+template = jinja2.Template('''{{code}}''')
+
+def main(**inputs):
+    return template.render(**inputs)
+
+# execute main function, and return the result
+output = main(**{{inputs}})
+
+result = f'''<<RESULT>>{output}<<RESULT>>'''
+
+print(result)
+
+"""
+
+class Jinja2TemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to python runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform jinja2 template to python code
+        runner = PYTHON_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4))
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+
+        return {
+            'result': result
+        }
\ No newline at end of file
diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py
new file mode 100644
index 00000000000000..27863ee4435354
--- /dev/null
+++ b/api/core/helper/code_executor/python_transformer.py
@@ -0,0 +1,55 @@
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+PYTHON_RUNNER = """import json  # needed for json.dumps below
+
+# declare main function here
+{{code}}
+
+# execute main function, and return the result
+# inputs is a dict, and it is unpacked into keyword arguments
+output = main(**{{inputs}})
+
+# convert output to json and print
+output = json.dumps(output, indent=4)
+
+result = f'''<<RESULT>>
+{output}
+<<RESULT>>'''
+
+print(result)
+"""
+
+
+class PythonTemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to python runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = PYTHON_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+        return json.loads(result)
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
new file mode 100644
index 00000000000000..5505df87493c02
--- /dev/null
+++ 
b/api/core/helper/code_executor/template_transformer.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + + +class TemplateTransformer(ABC): + @classmethod + @abstractmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + pass + + @classmethod + @abstractmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + pass \ No newline at end of file diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 86d6b498da35e7..20feae8554f79d 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,7 +1,7 @@ import logging import random -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from extensions.ext_hosting_provider import hosting_configuration @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -def check_moderation(model_config: ModelConfigEntity, text: str) -> bool: +def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config if (moderation_config and moderation_config.enabled is True and 'openai' in hosting_configuration.provider_map diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0bfe763fac66bd..c44d4717e6da33 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -38,6 +38,10 @@ def patch(url, *args, **kwargs): return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) def delete(url, *args, **kwargs): + if 'follow_redirects' in kwargs: + if kwargs['follow_redirects']: + kwargs['allow_redirects'] = kwargs['follow_redirects'] + kwargs.pop('follow_redirects') return _delete(url=url, *args, proxies=requests_proxies, **kwargs) def head(url, *args, **kwargs): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index dd46aa27dc612c..01a8ea3a5d0c6c 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/llm_generator/__init__.py b/api/core/llm_generator/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/generator/llm_generator.py b/api/core/llm_generator/llm_generator.py similarity index 93% rename from api/core/generator/llm_generator.py rename to api/core/llm_generator/llm_generator.py index 072b02dc94638a..1a6b71fb0ad20f 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -3,14 +3,14 @@ from langchain.schema import OutputParserException +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import 
SuggestedQuestionsAfterAnswerOutputParser +from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser -from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser -from core.prompt.prompt_template import PromptTemplateParser -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT +from core.prompt.utils.prompt_template_parser import PromptTemplateParser class LLMGenerator: diff --git a/api/core/llm_generator/output_parser/__init__.py b/api/core/llm_generator/output_parser/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py similarity index 94% rename from api/core/prompt/output_parser/rule_config_generator.py rename to api/core/llm_generator/output_parser/rule_config_generator.py index 619555ce2e99f8..b95653f69c6eea 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -2,7 +2,7 @@ from langchain.schema import BaseOutputParser, OutputParserException -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE +from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py similarity index 80% rename from api/core/prompt/output_parser/suggested_questions_after_answer.py rename to api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e37142ec9146c0..1b955c6edd2442 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -4,7 +4,8 @@ from langchain.schema import BaseOutputParser -from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.model_runtime.errors.invoke import InvokeError class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompts.py b/api/core/llm_generator/prompts.py similarity index 100% rename from api/core/prompt/prompts.py rename to api/core/llm_generator/prompts.py diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 4d44ac38183fb0..471400f09baffc 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,3 +1,4 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ( @@ -10,7 +11,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db -from 
models.model import Conversation, Message +from models.model import AppMode, Conversation, Message class TokenBufferMemory: @@ -43,9 +44,18 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, for message in messages: files = message.message_files if files: - file_objs = message_file_parser.transform_message_files( - files, message.app_model_config - ) + if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + else: + file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict) + + if file_upload_entity: + file_objs = message_file_parser.transform_message_files( + files, + file_upload_entity + ) + else: + file_objs = [] if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index aa16cf866f9327..8c0633992767dc 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -24,11 +24,11 @@ class ModelInstance: """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: - self._provider_model_bundle = provider_model_bundle + self.provider_model_bundle = provider_model_bundle self.model = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) - self.model_type_instance = self._provider_model_bundle.model_type_instance + self.model_type_instance = self.provider_model_bundle.model_type_instance def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ diff --git a/api/core/features/moderation.py b/api/core/moderation/input_moderation.py similarity index 75% rename from api/core/features/moderation.py rename to api/core/moderation/input_moderation.py index a9d65f56e85c70..8fbc0c2d5003f6 100644 --- a/api/core/features/moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,31 +1,31 @@ import logging -from core.entities.application_entities import AppOrchestrationConfigEntity +from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationException from core.moderation.factory import ModerationFactory logger = logging.getLogger(__name__) -class ModerationFeature: +class InputModeration: def check(self, app_id: str, tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, + app_config: AppConfig, inputs: dict, query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
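        Returns a tuple of (moderated, inputs, query); when no moderation rule
        is configured for the app, the original inputs and query pass through
        unchanged.
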
:param app_id: app id :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity + :param app_config: app config :param inputs: inputs :param query: query :return: """ - if not app_orchestration_config_entity.sensitive_word_avoidance: + if not app_config.sensitive_word_avoidance: return False, inputs, query - sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance + sensitive_word_avoidance_config = app_config.sensitive_word_avoidance moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( diff --git a/api/core/app_runner/moderation_handler.py b/api/core/moderation/output_moderation.py similarity index 86% rename from api/core/app_runner/moderation_handler.py rename to api/core/moderation/output_moderation.py index b2098344c843ac..af8910614da0cd 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,8 @@ from flask import Flask, current_app from pydantic import BaseModel -from core.application_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -18,14 +19,14 @@ class ModerationRule(BaseModel): config: dict[str, Any] -class OutputModerationHandler(BaseModel): +class OutputModeration(BaseModel): DEFAULT_BUFFER_SIZE: int = 300 tenant_id: str app_id: str rule: ModerationRule - on_message_replace_func: Any + queue_manager: AppQueueManager thread: Optional[threading.Thread] = None thread_running: bool = True @@ -67,7 +68,12 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> final_output = result.text if public_event: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) return final_output @@ -117,7 +123,12 @@ def worker(self, flask_app: Flask, buffer_size: int): # trigger replace event if self.thread_running: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) if result.action == ModerationAction.DIRECT_OUTPUT: break diff --git a/api/core/prompt/__init__.py b/api/core/prompt/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py new file mode 100644 index 00000000000000..60c77e943b3322 --- /dev/null +++ b/api/core/prompt/advanced_prompt_transform.py @@ -0,0 +1,233 @@ +from typing import Optional, Union + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.simple_prompt_transform import ModelMode +from 
core.prompt.utils.prompt_template_parser import PromptTemplateParser + + +class AdvancedPromptTransform(PromptTransform): + """ + Advanced Prompt Transform for Workflow LLM Node. + """ + + def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + prompt_messages = [] + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + prompt_messages = self._get_completion_model_prompt_messages( + prompt_template=prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config + ) + elif model_mode == ModelMode.CHAT: + prompt_messages = self._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config + ) + + return prompt_messages + + def _get_completion_model_prompt_messages(self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: list[FileObj], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + """ + Get completion model prompt messages. + """ + raw_prompt = prompt_template.text + + prompt_messages = [] + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + + if memory and memory_config: + role_prefix = memory_config.role_prefix + prompt_inputs = self._set_histories_variable( + memory=memory, + memory_config=memory_config, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) + + if query: + prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + + prompt = prompt_template.format( + prompt_inputs + ) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) + + return prompt_messages + + def _get_chat_model_prompt_messages(self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: list[FileObj], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + """ + Get chat model prompt messages. 
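+        Renders each ChatModelMessage template with the given inputs and context,
+        appends chat histories when a memory config is present, and attaches
+        file contents to the final user message.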
+ """ + raw_prompt_list = prompt_template + + prompt_messages = [] + + for prompt_item in raw_prompt_list: + raw_prompt = prompt_item.text + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = prompt_template.format( + prompt_inputs + ) + + if prompt_item.role == PromptMessageRole.USER: + prompt_messages.append(UserPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: + prompt_messages.append(SystemPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.ASSISTANT: + prompt_messages.append(AssistantPromptMessage(content=prompt)) + + if memory and memory_config: + prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + elif files: + if not query: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + elif query: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: + if '#context#' in prompt_template.variable_keys: + if context: + prompt_inputs['#context#'] = context + else: + prompt_inputs['#context#'] = '' + + return prompt_inputs + + def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: + if '#query#' in prompt_template.variable_keys: + if query: + prompt_inputs['#query#'] = query + else: + prompt_inputs['#query#'] = '' + + return prompt_inputs + + def _set_histories_variable(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, + raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity) -> dict: + if '#histories#' in prompt_template.variable_keys: + if memory: + inputs = {'#histories#': '', **prompt_inputs} + prompt_template = PromptTemplateParser(raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + tmp_human_message = UserPromptMessage( + content=prompt_template.format(prompt_inputs) + ) + + rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + + 
histories = self._get_history_messages_from_memory(
+                    memory=memory,
+                    memory_config=memory_config,
+                    max_token_limit=rest_tokens,
+                    human_prefix=role_prefix.user,
+                    ai_prefix=role_prefix.assistant
+                )
+                prompt_inputs['#histories#'] = histories
+            else:
+                prompt_inputs['#histories#'] = ''
+
+        return prompt_inputs
diff --git a/api/core/prompt/entities/__init__.py b/api/core/prompt/entities/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
new file mode 100644
index 00000000000000..97ac2e3e2a8651
--- /dev/null
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessageRole
+
+
+class ChatModelMessage(BaseModel):
+    """
+    Chat Message.
+    """
+    text: str
+    role: PromptMessageRole
+
+
+class CompletionModelPromptTemplate(BaseModel):
+    """
+    Completion Model Prompt Template.
+    """
+    text: str
+
+
+class MemoryConfig(BaseModel):
+    """
+    Memory Config.
+    """
+    class RolePrefix(BaseModel):
+        """
+        Role Prefix.
+        """
+        user: str
+        assistant: str
+
+    class WindowConfig(BaseModel):
+        """
+        Window Config.
+        """
+        enabled: bool
+        size: Optional[int] = None
+
+    role_prefix: Optional[RolePrefix] = None
+    window: WindowConfig
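These entities are the typed configuration consumed by AdvancedPromptTransform above. A small illustrative construction, with invented values:

# Illustrative values only; shows how the entities defined above fit together.
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, MemoryConfig

system_message = ChatModelMessage(
    text='You are a helpful assistant. {{#context#}}',
    role=PromptMessageRole.SYSTEM,
)

memory_config = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user='Human', assistant='Assistant'),
    window=MemoryConfig.WindowConfig(enabled=True, size=10),  # keep the last 10 messages
)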
diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json
deleted file mode 100644
index 709a8d88669d2d..00000000000000
--- a/api/core/prompt/generate_prompts/common_chat.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{
-    "human_prefix": "Human",
-    "assistant_prefix": "Assistant",
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
-    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
-    "system_prompt_orders": [
-        "context_prompt",
-        "pre_prompt",
-        "histories_prompt"
-    ],
-    "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
-    "stops": ["\nHuman:", "</histories>"]
-}
diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json
deleted file mode 100644
index 9e7e8d68ef333b..00000000000000
--- a/api/core/prompt/generate_prompts/common_completion.json
+++ /dev/null
@@ -1,9 +0,0 @@
-{
-    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
-    "system_prompt_orders": [
-        "context_prompt",
-        "pre_prompt"
-    ],
-    "query_prompt": "{{query}}",
-    "stops": null
-}
\ No newline at end of file
diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py
deleted file mode 100644
index 7727b0f92e83eb..00000000000000
--- a/api/core/prompt/prompt_builder.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from core.prompt.prompt_template import PromptTemplateParser
-
-
-class PromptBuilder:
-    @classmethod
-    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
-        prompt_template = PromptTemplateParser(prompt)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-        prompt = prompt_template.format(prompt_inputs)
-        return prompt
diff --git a/api/core/prompt/prompt_templates/__init__.py b/api/core/prompt/prompt_templates/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py
similarity index 100%
rename from api/core/prompt/advanced_prompt_templates.py
rename to api/core/prompt/prompt_templates/advanced_prompt_templates.py
diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/prompt_templates/baichuan_chat.json
similarity index 78%
rename from api/core/prompt/generate_prompts/baichuan_chat.json
rename to api/core/prompt/prompt_templates/baichuan_chat.json
index 5bf83cd9c7634b..03b6a53cfff2d1 100644
--- a/api/core/prompt/generate_prompts/baichuan_chat.json
+++ b/api/core/prompt/prompt_templates/baichuan_chat.json
@@ -1,13 +1,13 @@
 {
     "human_prefix": "用户",
     "assistant_prefix": "助手",
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n",
-    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
+    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt",
         "histories_prompt"
     ],
-    "query_prompt": "\n\n用户:{{query}}",
+    "query_prompt": "\n\n用户:{{#query#}}",
     "stops": ["用户:"]
 }
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/prompt_templates/baichuan_completion.json
similarity index 80%
rename from api/core/prompt/generate_prompts/baichuan_completion.json
rename to api/core/prompt/prompt_templates/baichuan_completion.json
index a3a2054e830b90..ae8c0dac53392f 100644
--- a/api/core/prompt/generate_prompts/baichuan_completion.json
+++ b/api/core/prompt/prompt_templates/baichuan_completion.json
@@ -1,9 +1,9 @@
 {
-    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n",
+    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
     "system_prompt_orders": [
         "context_prompt",
         "pre_prompt"
     ],
-    "query_prompt": "{{query}}",
+    "query_prompt": "{{#query#}}",
     "stops": null
 }
\ No newline at end of file
diff --git a/api/core/prompt/prompt_templates/common_chat.json b/api/core/prompt/prompt_templates/common_chat.json
new file mode 100644
index 00000000000000..d398a512e670a7
--- /dev/null
+++ b/api/core/prompt/prompt_templates/common_chat.json
@@ -0,0 +1,13 @@
+{
+    "human_prefix": "Human",
+    "assistant_prefix": "Assistant",
+    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
+    "system_prompt_orders": [
+        "context_prompt",
+        "pre_prompt",
+        "histories_prompt"
+    ],
+    "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
+    "stops": ["\nHuman:", "</histories>"]
+}
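The {{#context#}}, {{#histories#}} and {{#query#}} placeholders in these rule files are filled in by PromptTemplateParser (core/prompt/utils/prompt_template_parser.py) when the prompt is assembled. The sketch below is a simplified stand-in for that substitution, not the parser's actual implementation:

import re

# Simplified placeholder substitution over the {{...}} keys used in the JSON
# rules above; the real parser in core/prompt/utils does the actual work.
def render(template: str, inputs: dict) -> str:
    return re.sub(
        r'\{\{(#?[a-zA-Z_][a-zA-Z0-9_]*#?)\}\}',
        lambda m: str(inputs.get(m.group(1), m.group(0))),
        template,
    )

print(render('\n\nHuman: {{#query#}}\n\nAssistant: ', {'#query#': 'Hello!'}))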
"pre_prompt", + "histories_prompt" + ], + "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ", + "stops": ["\nHuman:", ""] +} diff --git a/api/core/prompt/prompt_templates/common_completion.json b/api/core/prompt/prompt_templates/common_completion.json new file mode 100644 index 00000000000000..c148772010fb05 --- /dev/null +++ b/api/core/prompt/prompt_templates/common_completion.json @@ -0,0 +1,9 @@ +{ + "context_prompt": "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n", + "system_prompt_orders": [ + "context_prompt", + "pre_prompt" + ], + "query_prompt": "{{#query#}}", + "stops": null +} \ No newline at end of file diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 0a373b7c42f8bd..9bf2ae090f7686 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,419 +1,26 @@ -import enum -import json -import os -import re from typing import Optional, cast -from core.entities.application_entities import ( - AdvancedCompletionPromptTemplateEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.file.file_obj import FileObj +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) +from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import PromptTemplateParser - - -class AppMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - -class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'ModelMode': - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') +from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: - def get_prompt(self, - app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name( - app_mode=app_mode, - provider=model_config.provider, - model=model_config.model - )) - - if app_mode == AppMode.CHAT and model_mode == ModelMode.CHAT: - stops = None - - prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - else: - stops = prompt_rules.get('stops') - if stops is not None and len(stops) == 0: - stops = None - - prompt_messages = self._get_simple_others_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - return prompt_messages, stops - - def get_advanced_prompt(self, app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_messages = [] - - if app_mode == AppMode.CHAT: - if model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_chat_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif model_mode == ModelMode.CHAT: - prompt_messages = self._get_chat_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif app_mode == AppMode.COMPLETION: - if model_mode == ModelMode.CHAT: - prompt_messages = self._get_completion_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - context=context, - ) - elif model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_completion_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - context=context, - ) - - return prompt_messages - - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: - """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } - - if human_prefix: - kwargs['human_prefix'] = human_prefix - - if ai_prefix: - kwargs['ai_prefix'] = ai_prefix - - return memory.get_history_prompt_text( - **kwargs - ) - - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - 
max_token_limit: int) -> list[PromptMessage]: - """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit - ) - - def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: - # baichuan - if provider == 'baichuan': - return self._prompt_file_name_for_baichuan(app_mode) - - baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): - return self._prompt_file_name_for_baichuan(app_mode) - - # common - if app_mode == AppMode.COMPLETION: - return 'common_completion' - else: - return 'common_chat' - - def _prompt_file_name_for_baichuan(self, app_mode: AppMode) -> str: - if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' - else: - return 'baichuan_chat' - - def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: - # Get the absolute path of the subdirectory - prompt_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'generate_prompts') - - json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') - # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: - return json.load(json_file) - - def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - prompt_messages = [] - - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - if prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - - self._append_chat_histories( - memory=memory, - prompt_messages=prompt_messages, - model_config=model_config - ) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_simple_others_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - files: list[FileObj], - model_config: ModelConfigEntity) -> list[PromptMessage]: - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - 
pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' - - if memory and 'histories_prompt' in prompt_rules: - # append chat histories - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=prompt + query_prompt, - inputs={ - 'query': query - } - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' - ) - prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) - histories_prompt_content = prompt_template.format({'histories': histories}) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' - elif order == 'histories_prompt': - prompt += histories_prompt_content - - prompt_template = PromptTemplateParser(template=query_prompt) - query_prompt_content = prompt_template.format({'query': query}) - - prompt += query_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - model_mode = ModelMode.value_of(model_config.mode) - - if model_mode == ModelMode.CHAT and files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - prompt_message = UserPromptMessage(content=prompt) - - return [prompt_message] - - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#context#' in prompt_template.variable_keys: - if context: - prompt_inputs['#context#'] = context - else: - prompt_inputs['#context#'] = '' - - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#query#' in prompt_template.variable_keys: - if query: - prompt_inputs['#query#'] = query - else: - prompt_inputs['#query#'] = '' - - def _set_histories_variable(self, memory: TokenBufferMemory, - raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigEntity) -> None: - if '#histories#' in prompt_template.variable_keys: - if memory: - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=raw_prompt, - inputs={'#histories#': '', **prompt_inputs} - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - human_prefix=role_prefix.user, - 
ai_prefix=role_prefix.assistant - ) - prompt_inputs['#histories#'] = histories - else: - prompt_inputs['#histories#'] = '' - def _append_chat_histories(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, prompt_messages: list[PromptMessage], - model_config: ModelConfigEntity) -> None: - if memory: - rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) - prompt_messages.extend(histories) + model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + rest_tokens = self._calculate_rest_token(prompt_messages, model_config) + histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) + prompt_messages.extend(histories) - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int: + return prompt_messages + + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -439,152 +46,38 @@ def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_conf return rest_tokens - def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str: - prompt = prompt_template.format( - prompt_inputs - ) - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - return prompt - - def _get_chat_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - - prompt_messages = [] + def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None) -> str: + """Get memory messages.""" + kwargs = { + "max_token_limit": max_token_limit + } - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + if human_prefix: + kwargs['human_prefix'] = human_prefix - self._set_context_variable(context, prompt_template, prompt_inputs) + if ai_prefix: + kwargs['ai_prefix'] = ai_prefix - self._set_query_variable(query, prompt_template, prompt_inputs) + if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: + kwargs['message_limit'] = memory_config.window.size - self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config + return memory.get_history_prompt_text( + **kwargs ) - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_chat_app_chat_model_prompt_messages(self, - 
prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - self._append_chat_histories(memory, prompt_messages, model_config) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_completion_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - context: Optional[str]) -> list[PromptMessage]: - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - - prompt_messages = [] - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_completion_app_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - files: list[FileObj], - context: Optional[str]) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - for prompt_message in prompt_messages[::-1]: - if prompt_message.role == PromptMessageRole.USER: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message.content = prompt_message_contents - break - - return prompt_messages + def 
_get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int) -> list[PromptMessage]: + """Get memory messages.""" + return memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if (memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0) + else 10 + ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py new file mode 100644 index 00000000000000..613716c2cf3b6c --- /dev/null +++ b/api/core/prompt/simple_prompt_transform.py @@ -0,0 +1,319 @@ +import enum +import json +import os +from typing import Optional + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from models.model import AppMode + + +class ModelMode(enum.Enum): + COMPLETION = 'completion' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'ModelMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + +prompt_file_contents = {} + + +class SimplePromptTransform(PromptTransform): + """ + Simple Prompt Transform for Chatbot App Basic Mode. 
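+
+    Illustrative usage sketch only (the prompt template entity, inputs and
+    model config below are placeholders, not part of this module):
+
+        transform = SimplePromptTransform()
+        prompt_messages, stops = transform.get_prompt(
+            app_mode=AppMode.CHAT,
+            prompt_template_entity=prompt_template_entity,
+            inputs={'target_language': 'English'},
+            query='Hello!',
+            files=[],
+            context=None,
+            memory=None,
+            model_config=model_config
+        )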
+ """ + + def get_prompt(self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity) -> \ + tuple[list[PromptMessage], Optional[list[str]]]: + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.CHAT: + prompt_messages, stops = self._get_chat_model_prompt_messages( + app_mode=app_mode, + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + else: + prompt_messages, stops = self._get_completion_model_prompt_messages( + app_mode=app_mode, + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + return prompt_messages, stops + + def get_prompt_str_and_rules(self, app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: + # get prompt template + prompt_template_config = self.get_prompt_template( + app_mode=app_mode, + provider=model_config.provider, + model=model_config.model, + pre_prompt=pre_prompt, + has_context=context is not None, + query_in_prompt=query is not None, + with_memory_prompt=histories is not None + ) + + variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + + for v in prompt_template_config['special_variable_keys']: + # support #context#, #query# and #histories# + if v == '#context#': + variables['#context#'] = context if context else '' + elif v == '#query#': + variables['#query#'] = query if query else '' + elif v == '#histories#': + variables['#histories#'] = histories if histories else '' + + prompt_template = prompt_template_config['prompt_template'] + prompt = prompt_template.format(variables) + + return prompt, prompt_template_config['prompt_rules'] + + def get_prompt_template(self, app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False) -> dict: + prompt_rules = self._get_prompt_rule( + app_mode=app_mode, + provider=provider, + model=model + ) + + custom_variable_keys = [] + special_variable_keys = [] + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt' and has_context: + prompt += prompt_rules['context_prompt'] + special_variable_keys.append('#context#') + elif order == 'pre_prompt' and pre_prompt: + prompt += pre_prompt + '\n' + pre_prompt_template = PromptTemplateParser(template=pre_prompt) + custom_variable_keys = pre_prompt_template.variable_keys + elif order == 'histories_prompt' and with_memory_prompt: + prompt += prompt_rules['histories_prompt'] + special_variable_keys.append('#histories#') + + if query_in_prompt: + prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}' + special_variable_keys.append('#query#') + + return { + "prompt_template": PromptTemplateParser(template=prompt), + "custom_variable_keys": custom_variable_keys, + "special_variable_keys": special_variable_keys, + "prompt_rules": prompt_rules + } + + def _get_chat_model_prompt_messages(self, app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, 
+                                        context: Optional[str],
+                                        files: list[FileObj],
+                                        memory: Optional[TokenBufferMemory],
+                                        model_config: ModelConfigWithCredentialsEntity) \
+            -> tuple[list[PromptMessage], Optional[list[str]]]:
+        prompt_messages = []
+
+        # get prompt
+        prompt, _ = self.get_prompt_str_and_rules(
+            app_mode=app_mode,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=None,
+            context=context
+        )
+
+        if prompt and query:
+            prompt_messages.append(SystemPromptMessage(content=prompt))
+
+        if memory:
+            prompt_messages = self._append_chat_histories(
+                memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
+                prompt_messages=prompt_messages,
+                model_config=model_config
+            )
+
+        if query:
+            prompt_messages.append(self.get_last_user_message(query, files))
+        else:
+            prompt_messages.append(self.get_last_user_message(prompt, files))
+
+        return prompt_messages, None
+
+    def _get_completion_model_prompt_messages(self, app_mode: AppMode,
+                                              pre_prompt: str,
+                                              inputs: dict,
+                                              query: str,
+                                              context: Optional[str],
+                                              files: list[FileObj],
+                                              memory: Optional[TokenBufferMemory],
+                                              model_config: ModelConfigWithCredentialsEntity) \
+            -> tuple[list[PromptMessage], Optional[list[str]]]:
+        # get prompt
+        prompt, prompt_rules = self.get_prompt_str_and_rules(
+            app_mode=app_mode,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=query,
+            context=context
+        )
+
+        if memory:
+            tmp_human_message = UserPromptMessage(
+                content=prompt
+            )
+
+            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+            histories = self._get_history_messages_from_memory(
+                memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
+                max_token_limit=rest_tokens,
+                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+            )
+
+            # get prompt
+            prompt, prompt_rules = self.get_prompt_str_and_rules(
+                app_mode=app_mode,
+                model_config=model_config,
+                pre_prompt=pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=context,
+                histories=histories
+            )
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return [self.get_last_user_message(prompt, files)], stops
+
+    def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage:
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_message = UserPromptMessage(content=prompt_message_contents)
+        else:
+            prompt_message = UserPromptMessage(content=prompt)
+
+        return prompt_message
+
+    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
+        """
+        Get simple prompt rule.
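+
+        Rule files are JSON documents; a minimal, hypothetical example of
+        their shape (the values below are illustrative, not the shipped
+        templates):
+
+            {
+                "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
+                "context_prompt": "Use the following context:\n{{#context#}}\n",
+                "histories_prompt": "{{#histories#}}\n",
+                "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
+                "human_prefix": "Human",
+                "assistant_prefix": "Assistant",
+                "stops": ["\nHuman:"]
+            }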
+ :param app_mode: app mode + :param provider: model provider + :param model: model name + :return: + """ + prompt_file_name = self._prompt_file_name( + app_mode=app_mode, + provider=provider, + model=model + ) + + # Check if the prompt file is already loaded + if prompt_file_name in prompt_file_contents: + return prompt_file_contents[prompt_file_name] + + # Get the absolute path of the subdirectory + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') + json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + + # Open the JSON file and read its content + with open(json_file_path, encoding='utf-8') as json_file: + content = json.load(json_file) + + # Store the content of the prompt file + prompt_file_contents[prompt_file_name] = content + + return content + + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: + # baichuan + is_baichuan = False + if provider == 'baichuan': + is_baichuan = True + else: + baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] + if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + is_baichuan = True + + if is_baichuan: + if app_mode == AppMode.COMPLETION: + return 'baichuan_completion' + else: + return 'baichuan_chat' + + # common + if app_mode == AppMode.COMPLETION: + return 'common_completion' + else: + return 'common_chat' diff --git a/api/core/prompt/utils/__init__.py b/api/core/prompt/utils/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py new file mode 100644 index 00000000000000..5fceeb3595c9d7 --- /dev/null +++ b/api/core/prompt/utils/prompt_message_util.py @@ -0,0 +1,85 @@ +from typing import cast + +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) +from core.prompt.simple_prompt_transform import ModelMode + + +class PromptMessageUtil: + @staticmethod + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + """ + Prompt messages to prompt for saving. + :param model_mode: model mode + :param prompt_messages: prompt messages + :return: + """ + prompts = [] + if model_mode == ModelMode.CHAT.value: + for prompt_message in prompt_messages: + if prompt_message.role == PromptMessageRole.USER: + role = 'user' + elif prompt_message.role == PromptMessageRole.ASSISTANT: + role = 'assistant' + elif prompt_message.role == PromptMessageRole.SYSTEM: + role = 'system' + else: + continue + + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' 
+ content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + prompts.append({ + "role": role, + "text": text, + "files": files + }) + else: + prompt_message = prompt_messages[0] + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + params = { + "role": 'user', + "text": text, + } + + if files: + params['files'] = files + + prompts.append(params) + + return prompts diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/utils/prompt_template_parser.py similarity index 93% rename from api/core/prompt/prompt_template.py rename to api/core/prompt/utils/prompt_template_parser.py index 32c5a791de4209..454f92e3b7dff5 100644 --- a/api/core/prompt/prompt_template.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -32,7 +32,8 @@ def replacer(match): return PromptTemplateParser.remove_template_variables(value) return value - return re.sub(REGEX, replacer, self.template) + prompt = re.sub(REGEX, replacer, self.template) + return re.sub(r'<\|.*?\|>', '', prompt) @classmethod def remove_template_variables(cls, text: str): diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6e28247d38d56a..0db84d3b6959a6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -235,7 +235,7 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D if available_models: found = False for available_model in available_models: - if available_model.model == "gpt-3.5-turbo-1106": + if available_model.model == "gpt-4": default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0d81c419d67ab1..139bfe15f328d6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from werkzeug.datastructures import FileStorage -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/core/rag/retrieval/__init__.py b/api/core/rag/retrieval/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/retrieval/agent/__init__.py b/api/core/rag/retrieval/agent/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/dataset_retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/fake_llm.py rename to api/core/rag/retrieval/agent/fake_llm.py diff --git a/api/core/features/dataset_retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py similarity index 78% rename from api/core/features/dataset_retrieval/agent/llm_chain.py rename to 
api/core/rag/retrieval/agent/llm_chain.py index e5155e15a0849f..f2c5d4ca33042b 100644 --- a/api/core/features/dataset_retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -5,19 +5,17 @@ from langchain.schema import Generation, LLMResult from langchain.schema.language_model import BaseLanguageModel -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance +from core.rag.retrieval.agent.fake_llm import FakeLLM class LLMChain(LCLLMChain): - model_config: ModelConfigEntity + model_config: ModelConfigWithCredentialsEntity """The language model instance to use.""" llm: BaseLanguageModel = FakeLLM(response="") parameters: dict[str, Any] = {} - agent_llm_callback: Optional[AgentLLMCallback] = None def generate( self, @@ -38,7 +36,6 @@ def generate( prompt_messages=prompt_messages, stream=False, stop=stop, - callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None, model_parameters=self.parameters ) diff --git a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py similarity index 96% rename from api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 59923202fde47a..be24731d46a394 100644 --- a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -10,18 +10,18 @@ from langchain.tools import BaseTool from pydantic import root_validator -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool +from core.rag.retrieval.agent.fake_llm import FakeLLM class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ An Multi Dataset Retrieve Agent driven by Router. 
""" - model_config: ModelConfigEntity + model_config: ModelConfigWithCredentialsEntity class Config: """Configuration for this pydantic object.""" @@ -156,7 +156,7 @@ async def aplan( @classmethod def from_llm_and_tools( cls, - model_config: ModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/rag/retrieval/agent/output_parser/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/agent/output_parser/structured_chat.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py rename to api/core/rag/retrieval/agent/output_parser/structured_chat.py diff --git a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py similarity index 98% rename from api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index e69302bfd682ea..7035ec8e2f7834 100644 --- a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py @@ -12,8 +12,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool -from core.entities.application_entities import ModelConfigEntity -from core.features.dataset_retrieval.agent.llm_chain import LLMChain +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
@@ -206,7 +206,7 @@ def _construct_scratchpad( @classmethod def from_llm_and_tools( cls, - model_config: ModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, diff --git a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py similarity index 87% rename from api/core/features/dataset_retrieval/agent_based_dataset_executor.py rename to api/core/rag/retrieval/agent_based_dataset_executor.py index 588ccc91f5f088..cb475bcffb7910 100644 --- a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -7,31 +7,29 @@ from langchain.tools import BaseTool from pydantic import BaseModel, Extra +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy -from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError +from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent +from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser +from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool class AgentConfiguration(BaseModel): strategy: PlanningStrategy - model_config: ModelConfigEntity + model_config: ModelConfigWithCredentialsEntity tools: list[BaseTool] - summary_model_config: Optional[ModelConfigEntity] = None + summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None memory: Optional[TokenBufferMemory] = None callbacks: Callbacks = None max_iterations: int = 6 max_execution_time: Optional[float] = None early_stopping_method: str = "generate" - agent_llm_callback: Optional[AgentLLMCallback] = None # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit class Config: diff --git a/api/core/features/dataset_retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py similarity index 95% rename from api/core/features/dataset_retrieval/dataset_retrieval.py rename to api/core/rag/retrieval/dataset_retrieval.py index 3e54d8644d3b46..ee728423262fce 100644 --- a/api/core/features/dataset_retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -2,22 +2,23 @@ from langchain.tools import BaseTool +from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from 
core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy -from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from models.dataset import Dataset -class DatasetRetrievalFeature: +class DatasetRetrieval: def retrieve(self, tenant_id: str, - model_config: ModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, config: DatasetEntity, query: str, invoke_from: InvokeFrom, diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 30128c4dcaeea8..1522d3af092cb6 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -2,9 +2,10 @@ from langchain.tools import BaseTool +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.tool.tool import Tool @@ -30,7 +31,7 @@ def get_dataset_tools(tenant_id: str, if retrieve_config is None: return [] - feature = DatasetRetrievalFeature() + feature = DatasetRetrieval() # save original retrieve strategy, and set retrieve strategy to SINGLE # Agent only support SINGLE mode diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab7421..600b54f1c20ad9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,8 +5,8 @@ from os import listdir, path from typing import Any, Union +from core.agent.entities import AgentToolEntity from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject @@ -34,6 +34,7 @@ ToolParameterConfigurationManager, ) from core.tools.utils.encoder import serialize_base_model_dict +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -225,6 +226,48 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten else: raise ToolProviderNotFoundError(f'provider type {provider_type} not 
found')
+    @staticmethod
+    def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
+        """
+        init runtime parameter
+        """
+        parameter_value = parameters.get(parameter_rule.name)
+        if not parameter_value:
+            # fall back to the default value
+            parameter_value = parameter_rule.default
+            if not parameter_value and parameter_rule.required:
+                raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
+
+        if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
+            # check that the parameter value is one of the allowed options
+            options = list(map(lambda x: x.value, parameter_rule.options))
+            if parameter_value not in options:
+                raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
+
+        # convert the parameter value to the correct type
+        try:
+            if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
+                # int and float values pass through unchanged; strings are parsed
+                if isinstance(parameter_value, str):
+                    if '.' in parameter_value:
+                        parameter_value = float(parameter_value)
+                    else:
+                        parameter_value = int(parameter_value)
+            elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
+                parameter_value = bool(parameter_value)
+            elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
+                parameter_value = str(parameter_value)
+        except Exception:
+            raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not of the correct type")
+
+        return parameter_value
+
     @staticmethod
     def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
         """
@@ -239,44 +282,9 @@ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_ca
         parameters = tool_entity.get_all_runtime_parameters()
         for parameter in parameters:
             if parameter.form == ToolParameter.ToolParameterForm.FORM:
-                # get tool parameter from form
-                tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
-                if not tool_parameter_config:
-                    # get default value
-                    tool_parameter_config = parameter.default
-                    if not tool_parameter_config and parameter.required:
-                        raise ValueError(f"tool parameter {parameter.name} not found in tool config")
-
-                if parameter.type == ToolParameter.ToolParameterType.SELECT:
-                    # check if tool_parameter_config in options
-                    options = list(map(lambda x: x.value, parameter.options))
-                    if tool_parameter_config not in options:
-                        raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
-
-                # convert tool parameter config to correct type
-                try:
-                    if parameter.type == ToolParameter.ToolParameterType.NUMBER:
-                        # check if tool parameter is integer
-                        if isinstance(tool_parameter_config, int):
-                            tool_parameter_config = tool_parameter_config
-                        elif isinstance(tool_parameter_config, float):
-                            tool_parameter_config = tool_parameter_config
-                        elif isinstance(tool_parameter_config, str):
-                            if '.' 
in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config + value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -289,6 +297,41 @@ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_ca tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity + + @staticmethod + def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler): + """ + get the workflow tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=workflow_tool.provider_type, + provider_name=workflow_tool.provider_id, + tool_name=workflow_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + + for parameter in parameters: + # save tool parameter to tool entity memory + if parameter.form == ToolParameter.ToolParameterForm.FORM: + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=workflow_tool.provider_id, + provider_type=workflow_tool.provider_type, + ) + + if runtime_parameters: + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity @staticmethod def get_builtin_provider_icon(provider: str) -> tuple[str, str]: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py new file mode 100644 index 00000000000000..3f456b4eb6adbe --- /dev/null +++ b/api/core/tools/utils/message_transformer.py @@ -0,0 +1,85 @@ +import logging +from mimetypes import guess_extension + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_file_manager import ToolFileManager + +logger = logging.getLogger(__name__) + +class ToolFileMessageTransformer: + @staticmethod + def transform_tool_invoke_messages(messages: list[ToolInvokeMessage], + user_id: str, + tenant_id: str, + conversation_id: str) -> list[ToolInvokeMessage]: + """ + Transform tool message and handle file download + """ + result = [] + + for message in messages: + if message.type == ToolInvokeMessage.MessageType.TEXT: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.LINK: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.IMAGE: + # try to download image + try: + file = ToolFileManager.create_file_by_url( + 
user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_url=message.message + ) + + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' + + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + except Exception as e: + logger.exception(e) + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + )) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get mime type and save blob to storage + mimetype = message.meta.get('mime_type', 'octet/stream') + # if message is str, encode it to bytes + if isinstance(message.message, str): + message.message = message.message.encode('utf-8') + + file = ToolFileManager.create_file_by_raw( + user_id=user_id, tenant_id=tenant_id, + conversation_id=conversation_id, + file_binary=message.message, + mimetype=mimetype + ) + + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' + + # check if file is image + if 'image' in mimetype: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(message) + + return result \ No newline at end of file diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py new file mode 100644 index 00000000000000..1f5472b430c96a --- /dev/null +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType + + +class BaseWorkflowCallback(ABC): + @abstractmethod + def on_workflow_run_started(self) -> None: + """ + Workflow run started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_run_succeeded(self) -> None: + """ + Workflow run succeeded + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: + """ + Workflow node execute started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: + """ + Workflow 
node execute succeeded + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: + """ + Workflow node execute failed + """ + raise NotImplementedError + + @abstractmethod + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + """ + Publish text chunk + """ + raise NotImplementedError diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py new file mode 100644 index 00000000000000..fc6ee231ffc7ea --- /dev/null +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -0,0 +1,9 @@ +from abc import ABC +from typing import Optional + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + title: str + desc: Optional[str] = None diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py new file mode 100644 index 00000000000000..befabfb3b4e333 --- /dev/null +++ b/api/core/workflow/entities/node_entities.py @@ -0,0 +1,71 @@ +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + +from models.workflow import WorkflowNodeExecutionStatus + + +class NodeType(Enum): + """ + Node Types. + """ + START = 'start' + END = 'end' + ANSWER = 'answer' + LLM = 'llm' + KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' + IF_ELSE = 'if-else' + CODE = 'code' + TEMPLATE_TRANSFORM = 'template-transform' + QUESTION_CLASSIFIER = 'question-classifier' + HTTP_REQUEST = 'http-request' + TOOL = 'tool' + VARIABLE_ASSIGNER = 'variable-assigner' + + @classmethod + def value_of(cls, value: str) -> 'NodeType': + """ + Get value of given node type. + + :param value: node type value + :return: node type + """ + for node_type in cls: + if node_type.value == value: + return node_type + raise ValueError(f'invalid node type value {value}') + + +class SystemVariable(Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION = 'conversation' + + +class NodeRunMetadataKey(Enum): + """ + Node Run Metadata Key. + """ + TOTAL_TOKENS = 'total_tokens' + TOTAL_PRICE = 'total_price' + CURRENCY = 'currency' + + +class NodeRunResult(BaseModel): + """ + Node Run Result. + """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING + + inputs: Optional[dict] = None # node inputs + process_data: Optional[dict] = None # process data + outputs: Optional[dict] = None # node outputs + metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + + edge_source_handle: Optional[str] = None # source handle id of node with multiple branches + + error: Optional[str] = None # error message if status is failed diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py new file mode 100644 index 00000000000000..19d9af2a6171a4 --- /dev/null +++ b/api/core/workflow/entities/variable_entities.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class VariableSelector(BaseModel): + """ + Variable Selector. 
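+
+    For example (illustrative identifiers), selecting the 'text' output of a
+    node whose id is 'llm':
+
+        VariableSelector(variable='text', value_selector=['llm', 'text'])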
+    """
+    variable: str
+    value_selector: list[str]
diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
new file mode 100644
index 00000000000000..ff96bc3bac0276
--- /dev/null
+++ b/api/core/workflow/entities/variable_pool.py
@@ -0,0 +1,91 @@
+from enum import Enum
+from typing import Any, Optional, Union
+
+from core.workflow.entities.node_entities import SystemVariable
+
+VariableValue = Union[str, int, float, dict, list]
+
+
+class ValueType(Enum):
+    """
+    Value Type Enum
+    """
+    STRING = "string"
+    NUMBER = "number"
+    OBJECT = "object"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    ARRAY_FILE = "array[file]"
+    FILE = "file"
+
+
+class VariablePool:
+    user_inputs: dict
+
+    def __init__(self, system_variables: dict[SystemVariable, Any],
+                 user_inputs: dict) -> None:
+        # system variables
+        # for example:
+        # {
+        #     'query': 'abc',
+        #     'files': []
+        # }
+        self.variables_mapping = {}  # per-instance mapping, so pools don't share state across runs
+        self.user_inputs = user_inputs
+        for system_variable, value in system_variables.items():
+            self.append_variable('sys', [system_variable.value], value)
+
+    def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
+        """
+        Append variable
+        :param node_id: node id
+        :param variable_key_list: variable key list, like: ['result', 'text']
+        :param value: value
+        :return:
+        """
+        if node_id not in self.variables_mapping:
+            self.variables_mapping[node_id] = {}
+
+        variable_key_list_hash = hash(tuple(variable_key_list))
+
+        self.variables_mapping[node_id][variable_key_list_hash] = value
+
+    def get_variable_value(self, variable_selector: list[str],
+                           target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
+        """
+        Get variable
+        :param variable_selector: include node_id and variables
+        :param target_value_type: target value type
+        :return:
+        """
+        if len(variable_selector) < 2:
+            raise ValueError('Invalid value selector')
+
+        node_id = variable_selector[0]
+        if node_id not in self.variables_mapping:
+            return None
+
+        # fetch variable keys, pop node_id
+        variable_key_list = variable_selector[1:]
+
+        variable_key_list_hash = hash(tuple(variable_key_list))
+
+        value = self.variables_mapping[node_id].get(variable_key_list_hash)
+
+        if target_value_type:
+            if target_value_type == ValueType.STRING:
+                return str(value)
+            elif target_value_type == ValueType.NUMBER:
+                return int(value)
+            elif target_value_type == ValueType.OBJECT:
+                if not isinstance(value, dict):
+                    raise ValueError('Invalid value type: object')
+            elif target_value_type in [ValueType.ARRAY_STRING,
+                                       ValueType.ARRAY_NUMBER,
+                                       ValueType.ARRAY_OBJECT,
+                                       ValueType.ARRAY_FILE]:
+                if not isinstance(value, list):
+                    raise ValueError(f'Invalid value type: {target_value_type.value}')
+
+        return value
diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py
new file mode 100644
index 00000000000000..a78bf09a531c14
--- /dev/null
+++ b/api/core/workflow/entities/workflow_entities.py
@@ -0,0 +1,46 @@
+from typing import Optional
+
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseNode, UserFrom
+from models.workflow import Workflow, WorkflowType
+
+
+class WorkflowNodeAndResult:
+    node: BaseNode
+    result: Optional[NodeRunResult] = None
+
+    def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None):
+        self.node = node
+ 
self.result = result + + +class WorkflowRunState: + tenant_id: str + app_id: str + workflow_id: str + workflow_type: WorkflowType + user_id: str + user_from: UserFrom + + start_at: float + variable_pool: VariablePool + + total_tokens: int = 0 + + workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] + + def __init__(self, workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom): + self.workflow_id = workflow.id + self.tenant_id = workflow.tenant_id + self.app_id = workflow.app_id + self.workflow_type = WorkflowType.value_of(workflow.type) + self.user_id = user_id + self.user_from = user_from + + self.start_at = start_at + self.variable_pool = variable_pool diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py new file mode 100644 index 00000000000000..fe79fadf66b876 --- /dev/null +++ b/api/core/workflow/errors.py @@ -0,0 +1,10 @@ +from core.workflow.entities.node_entities import NodeType + + +class WorkflowNodeRunFailedError(Exception): + def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): + self.node_id = node_id + self.node_type = node_type + self.node_title = node_title + self.error = error + super().__init__(f"Node {node_title} run failed: {error}") diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py new file mode 100644 index 00000000000000..d8ff5cb6f630d1 --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -0,0 +1,152 @@ +from typing import cast + +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.nodes.base_node import BaseNode +from models.workflow import WorkflowNodeExecutionStatus + + +class AnswerNode(BaseNode): + _node_data_cls = AnswerNodeData + node_type = NodeType.ANSWER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + # generate routes + generate_routes = self.extract_generate_route_from_node_data(node_data) + + answer = [] + for part in generate_routes: + if part.type == "var": + part = cast(VarGenerateRouteChunk, part) + value_selector = part.value_selector + value = variable_pool.get_variable_value( + variable_selector=value_selector, + target_value_type=ValueType.STRING + ) + + answer_part = { + "type": "text", + "text": value + } + # TODO File + else: + part = cast(TextGenerateRouteChunk, part) + answer_part = { + "type": "text", + "text": part.text + } + + if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": + answer[-1]["text"] += answer_part["text"] + else: + answer.append(answer_part) + + if len(answer) == 1 and answer[0]["type"] == "text": + answer = answer[0]["text"] + + # re-fetch 
variable values + variable_values = {} + for variable_selector in node_data.variables: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector, + target_value_type=ValueType.STRING + ) + + variable_values[variable_selector.variable] = value + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variable_values, + outputs={ + "answer": answer + } + ) + + @classmethod + def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + node_data = cast(cls._node_data_cls, node_data) + + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(node_data.answer) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + generate_routes = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) + + return generate_routes + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + node_data = cast(cls._node_data_cls, node_data) + + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + return variable_mapping diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py new file mode 100644 index 00000000000000..8aed752ccb55e6 --- /dev/null +++ b/api/core/workflow/nodes/answer/entities.py @@ -0,0 +1,36 @@ + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class AnswerNodeData(BaseNodeData): + """ + Answer Node Data. + """ + variables: list[VariableSelector] = [] + answer: str + + +class GenerateRouteChunk(BaseModel): + """ + Generate Route Chunk. + """ + type: str + + +class VarGenerateRouteChunk(GenerateRouteChunk): + """ + Var Generate Route Chunk. 
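How the Ω-sentinel split in extract_generate_route_from_node_data works, as a standalone sketch:

template = 'Hi {{name}}!'
for var in ['name']:  # keys present in both node variables and the template
    template = template.replace('{{' + var + '}}', 'Ω{{' + var + '}}Ω')
assert template.split('Ω') == ['Hi ', '{{name}}', '!']
# -> TextGenerateRouteChunk('Hi '), VarGenerateRouteChunk for 'name', TextGenerateRouteChunk('!')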
+ """ + type: str = "var" + value_selector: list[str] + + +class TextGenerateRouteChunk(GenerateRouteChunk): + """ + Text Generate Route Chunk. + """ + type: str = "text" + text: str diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py new file mode 100644 index 00000000000000..7cc9c6ee3dba81 --- /dev/null +++ b/api/core/workflow/nodes/base_node.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional + +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool + + +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") + + +class BaseNode(ABC): + _node_data_cls: type[BaseNodeData] + _node_type: NodeType + + tenant_id: str + app_id: str + workflow_id: str + user_id: str + user_from: UserFrom + + node_id: str + node_data: BaseNodeData + node_run_result: Optional[NodeRunResult] = None + + callbacks: list[BaseWorkflowCallback] + + def __init__(self, tenant_id: str, + app_id: str, + workflow_id: str, + user_id: str, + user_from: UserFrom, + config: dict, + callbacks: list[BaseWorkflowCallback] = None) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.user_id = user_id + self.user_from = user_from + + self.node_id = config.get("id") + if not self.node_id: + raise ValueError("Node ID is required.") + + self.node_data = self._node_data_cls(**config.get("data", {})) + self.callbacks = callbacks or [] + + @abstractmethod + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + raise NotImplementedError + + def run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node entry + :param variable_pool: variable pool + :return: + """ + result = self._run( + variable_pool=variable_pool + ) + + self.node_run_result = result + return result + + def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: + """ + Publish text chunk + :param text: chunk text + :param value_selector: value selector + :return: + """ + if self.callbacks: + for callback in self.callbacks: + callback.on_node_text_chunk( + node_id=self.node_id, + text=text, + metadata={ + "node_type": self.node_type, + "value_selector": value_selector + } + ) + + @classmethod + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping(node_data) + + @classmethod + @abstractmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + raise NotImplementedError + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. 
+ :param filters: filter by node config parameters. + :return: + """ + return {} + + @property + def node_type(self) -> NodeType: + """ + Get node type + :return: + """ + return self._node_type diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py new file mode 100644 index 00000000000000..0b46f86e9d22e7 --- /dev/null +++ b/api/core/workflow/nodes/code/code_node.py @@ -0,0 +1,308 @@ +from typing import Optional, Union, cast + +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.code.entities import CodeNodeData +from models.workflow import WorkflowNodeExecutionStatus + +MAX_NUMBER = 2 ** 63 - 1 +MIN_NUMBER = -2 ** 63 +MAX_PRECISION = 20 +MAX_DEPTH = 5 +MAX_STRING_LENGTH = 1000 +MAX_STRING_ARRAY_LENGTH = 30 +MAX_NUMBER_ARRAY_LENGTH = 1000 + +JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) { + return { + result: args1 + args2 + } +}""" + +PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + }""" + +class CodeNode(BaseNode): + _node_data_cls = CodeNodeData + node_type = NodeType.CODE + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + if filters and filters.get("code_language") == "javascript": + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "javascript", + "code": JAVASCRIPT_DEFAULT_CODE, + "outputs": { + "result": { + "type": "number", + "children": None + } + } + } + } + + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "python3", + "code": PYTHON_DEFAULT_CODE, + "outputs": { + "result": { + "type": "number", + "children": None + } + } + } + } + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run code + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data: CodeNodeData = cast(self._node_data_cls, node_data) + + # Get code language + code_language = node_data.code_language + code = node_data.code + + # Get variables + variables = {} + for variable_selector in node_data.variables: + variable = variable_selector.variable + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + + variables[variable] = value + # Run code + try: + result = CodeExecutor.execute_code( + language=code_language, + code=code, + inputs=variables + ) + + # Transform result + result = self._transform_result(result, node_data.outputs) + except (CodeExecutionException, ValueError) as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs=result + ) + + def _check_string(self, value: str, variable: str) -> str: + """ + Check string 
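For orientation, the call that _run above delegates to, shown with the default Python snippet (a sketch; the exact contract of CodeExecutor lives in core.helper.code_executor):

result = CodeExecutor.execute_code(
    language='python3',
    code=PYTHON_DEFAULT_CODE,
    inputs={'args1': 1, 'args2': 2},
)
# expected shape: {'result': 3}; _transform_result then checks it against node_data.outputs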
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, str):
+            raise ValueError(f"{variable} in output form must be a string")
+
+        if len(value) > MAX_STRING_LENGTH:
+            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
+
+        return value.replace('\x00', '')
+
+    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
+        """
+        Check number
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, int | float):
+            raise ValueError(f"{variable} in output form must be a number")
+
+        if value > MAX_NUMBER or value < MIN_NUMBER:
+            raise ValueError(f'{variable} in output form is out of range.')
+
+        if isinstance(value, float):
+            # raise error if precision is too high
+            if len(str(value).split('.')[1]) > MAX_PRECISION:
+                raise ValueError(f'{variable} in output form has too high precision.')
+
+        return value
+
+    def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
+                          prefix: str = '',
+                          depth: int = 1) -> dict:
+        """
+        Transform result
+        :param result: result
+        :param output_schema: output schema
+        :return:
+        """
+        if depth > MAX_DEPTH:
+            raise ValueError("Depth limit reached, object too deep.")
+
+        transformed_result = {}
+        if output_schema is None:
+            # validate output through instance type
+            for output_name, output_value in result.items():
+                if isinstance(output_value, dict):
+                    self._transform_result(
+                        result=output_value,
+                        output_schema=None,
+                        prefix=f'{prefix}.{output_name}' if prefix else output_name,
+                        depth=depth + 1
+                    )
+                elif isinstance(output_value, int | float):
+                    self._check_number(
+                        value=output_value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                elif isinstance(output_value, str):
+                    self._check_string(
+                        value=output_value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                elif isinstance(output_value, list):
+                    if all(isinstance(value, int | float) for value in output_value):
+                        for value in output_value:
+                            self._check_number(
+                                value=value,
+                                variable=f'{prefix}.{output_name}' if prefix else output_name
+                            )
+                    elif all(isinstance(value, str) for value in output_value):
+                        for value in output_value:
+                            self._check_string(
+                                value=value,
+                                variable=f'{prefix}.{output_name}' if prefix else output_name
+                            )
+                    else:
+                        raise ValueError(f'Output {prefix}.{output_name} is not a valid array. Make sure all elements are of the same type.')
+                else:
+                    raise ValueError(f'Output {prefix}.{output_name} is not a valid type.')
+
+            return result
+
+        parameters_validated = {}
+        for output_name, output_config in output_schema.items():
+            if output_config.type == 'object':
+                # check if output is object
+                if not isinstance(result.get(output_name), dict):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
+ ) + + transformed_result[output_name] = self._transform_result( + result=result[output_name], + output_schema=output_config.children, + prefix=f'{prefix}.{output_name}' if prefix else output_name, + depth=depth + 1 + ) + elif output_config.type == 'number': + # check if number available + transformed_result[output_name] = self._check_number( + value=result[output_name], + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif output_config.type == 'string': + # check if string available + transformed_result[output_name] = self._check_string( + value=result[output_name], + variable=f'{prefix}.{output_name}' if prefix else output_name, + ) + elif output_config.type == 'array[number]': + # check if array of number available + if not isinstance(result[output_name], list): + raise ValueError( + f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' + ) + + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: + raise ValueError( + f'{prefix}.{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters' + ) + + transformed_result[output_name] = [ + self._check_number( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + for value in result[output_name] + ] + elif output_config.type == 'array[string]': + # check if array of string available + if not isinstance(result[output_name], list): + raise ValueError( + f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' + ) + + if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: + raise ValueError( + f'{prefix}.{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters' + ) + + transformed_result[output_name] = [ + self._check_string( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + for value in result[output_name] + ] + else: + raise ValueError(f'Output type {output_config.type} is not supported.') + + parameters_validated[output_name] = True + + # check if all output parameters are validated + if len(parameters_validated) != len(result): + raise ValueError('Not all output parameters are validated.') + + return transformed_result + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + + return { + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py new file mode 100644 index 00000000000000..97e178f5df9112 --- /dev/null +++ b/api/core/workflow/nodes/code/entities.py @@ -0,0 +1,20 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class CodeNodeData(BaseNodeData): + """ + Code Node Data. 
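A worked example of _transform_result (sketch, called on a CodeNode instance), using the default schema from get_default_config, i.e. outputs == {'result': Output(type='number', children=None)}:

self._transform_result({'result': 3}, outputs)              # -> {'result': 3}
self._transform_result({'result': 'x'}, outputs)            # ValueError: result in output form must be a number
self._transform_result({'result': 3, 'extra': 1}, outputs)  # ValueError: Not all output parameters are validated.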
+ """ + class Output(BaseModel): + type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] + children: Optional[dict[str, 'Output']] + + variables: list[VariableSelector] + code_language: Literal['python3', 'javascript'] + code: str + outputs: dict[str, Output] \ No newline at end of file diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py new file mode 100644 index 00000000000000..3241860c298a93 --- /dev/null +++ b/api/core/workflow/nodes/end/end_node.py @@ -0,0 +1,45 @@ +from typing import cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.end.entities import EndNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class EndNode(BaseNode): + _node_data_cls = EndNodeData + node_type = NodeType.END + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + output_variables = node_data.outputs + + outputs = {} + for variable_selector in output_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=outputs, + outputs=outputs + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py new file mode 100644 index 00000000000000..ad4fc8f04fd43c --- /dev/null +++ b/api/core/workflow/nodes/end/entities.py @@ -0,0 +1,9 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class EndNodeData(BaseNodeData): + """ + END Node Data. + """ + outputs: list[VariableSelector] diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py new file mode 100644 index 00000000000000..0683008954de26 --- /dev/null +++ b/api/core/workflow/nodes/http_request/entities.py @@ -0,0 +1,45 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, validator + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class HttpRequestNodeData(BaseNodeData): + """ + Code Node Data. 
+ """ + class Authorization(BaseModel): + class Config(BaseModel): + type: Literal[None, 'basic', 'bearer', 'custom'] + api_key: Union[None, str] + header: Union[None, str] + + type: Literal['no-auth', 'api-key'] + config: Optional[Config] + + @validator('config', always=True, pre=True) + def check_config(cls, v, values): + """ + Check config, if type is no-auth, config should be None, otherwise it should be a dict. + """ + if values['type'] == 'no-auth': + return None + else: + if not v or not isinstance(v, dict): + raise ValueError('config should be a dict') + + return v + + class Body(BaseModel): + type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] + data: Union[None, str] + + variables: list[VariableSelector] + method: Literal['get', 'post', 'put', 'patch', 'delete'] + url: str + authorization: Authorization + headers: str + params: str + body: Optional[Body] \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py new file mode 100644 index 00000000000000..3d307be0d1f58b --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -0,0 +1,272 @@ +import re +from copy import deepcopy +from typing import Any, Union +from urllib.parse import urlencode + +import httpx +import requests + +import core.helper.ssrf_proxy as ssrf_proxy +from core.workflow.nodes.http_request.entities import HttpRequestNodeData + +HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60) + +class HttpExecutorResponse: + status_code: int + headers: dict[str, str] + body: str + + def __init__(self, status_code: int, headers: dict[str, str], body: str): + """ + init + """ + self.status_code = status_code + self.headers = headers + self.body = body + +class HttpExecutor: + server_url: str + method: str + authorization: HttpRequestNodeData.Authorization + params: dict[str, Any] + headers: dict[str, Any] + body: Union[None, str] + files: Union[None, dict[str, Any]] + + def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init + """ + self.server_url = node_data.url + self.method = node_data.method + self.authorization = node_data.authorization + self.params = {} + self.headers = {} + self.body = None + self.files = None + + # init template + self._init_template(node_data, variables) + + def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init template + """ + # extract all template in url + url_template = re.findall(r'{{(.*?)}}', node_data.url) or [] + url_template = list(set(url_template)) + original_url = node_data.url + for url in url_template: + if not url: + continue + + original_url = original_url.replace(f'{{{{{url}}}}}', str(variables.get(url, ''))) + + self.server_url = original_url + + # extract all template in params + param_template = re.findall(r'{{(.*?)}}', node_data.params) or [] + param_template = list(set(param_template)) + original_params = node_data.params + for param in param_template: + if not param: + continue + + original_params = original_params.replace(f'{{{{{param}}}}}', str(variables.get(param, ''))) + + # fill in params + kv_paris = original_params.split('\n') + for kv in kv_paris: + if not kv.strip(): + continue + + kv = kv.split(':') + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: + raise ValueError(f'Invalid params {kv}') + + self.params[k] = v + + # extract all template in headers + header_template = re.findall(r'{{(.*?)}}', node_data.headers) or [] + header_template = 
list(set(header_template)) + original_headers = node_data.headers + for header in header_template: + if not header: + continue + + original_headers = original_headers.replace(f'{{{{{header}}}}}', str(variables.get(header, ''))) + + # fill in headers + kv_paris = original_headers.split('\n') + for kv in kv_paris: + if not kv.strip(): + continue + + kv = kv.split(':') + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: + raise ValueError(f'Invalid headers {kv}') + + self.headers[k] = v + + # extract all template in body + if node_data.body: + body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] + body_template = list(set(body_template)) + original_body = node_data.body.data or '' + for body in body_template: + if not body: + continue + + original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) + + if node_data.body.type == 'json': + self.headers['Content-Type'] = 'application/json' + elif node_data.body.type == 'x-www-form-urlencoded': + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + + if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + body = {} + kv_paris = original_body.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) == 2: + body[kv[0]] = kv[1] + elif len(kv) == 1: + body[kv[0]] = '' + else: + raise ValueError(f'Invalid body {kv}') + + if node_data.body.type == 'form-data': + self.files = { + k: ('', v) for k, v in body.items() + } + else: + self.body = urlencode(body) + elif node_data.body.type in ['json', 'raw']: + self.body = original_body + elif node_data.body.type == 'none': + self.body = '' + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.authorization) + headers = deepcopy(self.headers) or {} + if self.authorization.type == 'api-key': + if self.authorization.config.api_key is None: + raise ValueError('api_key is required') + + if not self.authorization.config.header: + authorization.config.header = 'Authorization' + + if self.authorization.config.type == 'bearer': + headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' + elif self.authorization.config.type == 'basic': + headers[authorization.config.header] = f'Basic {authorization.config.api_key}' + elif self.authorization.config.type == 'custom': + headers[authorization.config.header] = authorization.config.api_key + + return headers + + def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: + """ + validate the response + """ + if isinstance(response, httpx.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + elif isinstance(response, requests.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + else: + raise ValueError(f'Invalid response type {type(response)}') + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + # do http request + kwargs = { + 'url': self.server_url, + 'headers': headers, + 'params': self.params, + 'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT, + 'follow_redirects': True + } + + if self.method == 'get': + response = ssrf_proxy.get(**kwargs) + elif self.method == 'post': + response = 
ssrf_proxy.post(data=self.body, files=self.files, **kwargs) + elif self.method == 'put': + response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs) + elif self.method == 'delete': + response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs) + elif self.method == 'patch': + response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs) + elif self.method == 'head': + response = ssrf_proxy.head(**kwargs) + elif self.method == 'options': + response = ssrf_proxy.options(**kwargs) + else: + raise ValueError(f'Invalid http method {self.method}') + + return response + + def invoke(self) -> HttpExecutorResponse: + """ + invoke http request + """ + # assemble headers + headers = self._assembling_headers() + + # do http request + response = self._do_http_request(headers) + + # validate response + return self._validate_and_parse_response(response) + + def to_raw_request(self) -> str: + """ + convert to raw request + """ + server_url = self.server_url + if self.params: + server_url += f'?{urlencode(self.params)}' + + raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + + headers = self._assembling_headers() + for k, v in headers.items(): + raw_request += f'{k}: {v}\n' + + raw_request += '\n' + + # if files, use multipart/form-data with boundary + if self.files: + boundary = '----WebKitFormBoundary7MA4YWxkTrZu0gW' + raw_request = f'--{boundary}\n' + raw_request + for k, v in self.files.items(): + raw_request += f'Content-Disposition: form-data; name="{k}"; filename="{v[0]}"\n' + raw_request += f'Content-Type: {v[1]}\n\n' + raw_request += v[1] + '\n' + raw_request += f'--{boundary}\n' + raw_request += '--\n' + else: + raw_request += self.body or '' + + return raw_request \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py new file mode 100644 index 00000000000000..a914ae13ff1b0b --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -0,0 +1,63 @@ +from typing import cast + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.http_executor import HttpExecutor +from models.workflow import WorkflowNodeExecutionStatus + + +class HttpRequestNode(BaseNode): + _node_data_cls = HttpRequestNodeData + node_type = NodeType.HTTP_REQUEST + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value(variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # init http executor + try: + http_executor = HttpExecutor(node_data=node_data, variables=variables) + + # invoke http executor + response = http_executor.invoke() + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + process_data={ + 'request': http_executor.to_raw_request() + } + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs={ + 'status_code': response.status_code, + 'body': response.body, + 'headers': response.headers + }, + process_data={ + 'request': 
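An end-to-end usage sketch of HttpExecutor (illustration only; the field values are hypothetical, and the title field is assumed to come from BaseNodeData):

node_data = HttpRequestNodeData(
    title='demo',
    variables=[],
    method='get',
    url='https://example.com/{{path}}',
    authorization=HttpRequestNodeData.Authorization(type='no-auth', config=None),
    headers='Accept:application/json',
    params='',
    body=None,
)
executor = HttpExecutor(node_data=node_data, variables={'path': 'ping'})
resp = executor.invoke()  # -> HttpExecutorResponse(status_code, headers, body)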
http_executor.to_raw_request(), + } + ) + + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return { + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py new file mode 100644 index 00000000000000..68d51c93bee1ac --- /dev/null +++ b/api/core/workflow/nodes/if_else/entities.py @@ -0,0 +1,26 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class IfElseNodeData(BaseNodeData): + """ + Answer Node Data. + """ + class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None + + logical_operator: Literal["and", "or"] = "and" + conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py new file mode 100644 index 00000000000000..44a4091a2efc6e --- /dev/null +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -0,0 +1,398 @@ +from typing import Optional, cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.if_else.entities import IfElseNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class IfElseNode(BaseNode): + _node_data_cls = IfElseNodeData + node_type = NodeType.IF_ELSE + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + node_inputs = { + "conditions": [] + } + + process_datas = { + "condition_results": [] + } + + try: + logical_operator = node_data.logical_operator + input_conditions = [] + for condition in node_data.conditions: + actual_value = variable_pool.get_variable_value( + variable_selector=condition.variable_selector + ) + + expected_value = condition.value + + input_conditions.append({ + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator + }) + + node_inputs["conditions"] = input_conditions + + for input_condition in input_conditions: + actual_value = input_condition["actual_value"] + expected_value = input_condition["expected_value"] + comparison_operator = input_condition["comparison_operator"] + + if comparison_operator == "contains": + compare_result = self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + compare_result = self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + compare_result = self._assert_start_with(actual_value, 
expected_value) + elif comparison_operator == "end with": + compare_result = self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + compare_result = self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + compare_result = self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + compare_result = self._assert_empty(actual_value) + elif comparison_operator == "not empty": + compare_result = self._assert_not_empty(actual_value) + elif comparison_operator == "=": + compare_result = self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + compare_result = self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + compare_result = self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + compare_result = self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + compare_result = self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + compare_result = self._assert_null(actual_value) + elif comparison_operator == "not null": + compare_result = self._assert_not_null(actual_value) + else: + continue + + process_datas["condition_results"].append({ + **input_condition, + "result": compare_result + }) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=node_inputs, + process_data=process_datas, + error=str(e) + ) + + if logical_operator == "and": + compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]] + else: + compare_result = True in [condition["result"] for condition in process_datas["condition_results"]] + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_datas, + edge_source_handle="false" if not compare_result else "true", + outputs={ + "result": compare_result + } + ) + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + 
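A single condition as modelled by the entities module above (sketch):

cond = IfElseNodeData.Condition(
    variable_selector=['start', 'age'],
    comparison_operator='≥',
    value='18',
)
# with logical_operator 'and', edge_source_handle becomes 'true' only when every
# condition evaluates True; the numeric operators coerce value via int()/float() first.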
Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value 
type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py new file mode 100644 index 00000000000000..d9934acff9c619 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py @@ -0,0 +1,194 @@ +import threading +from typing import Optional + +from flask import Flask, current_app +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(BaseTool): + 
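A construction sketch for DatasetMultiRetrieverTool (ids and model names are hypothetical placeholders); the tool fans out one thread per dataset, then merges the hits with a single rerank pass, with each worker re-entering the Flask app context before touching the database:

tool = DatasetMultiRetrieverTool.from_dataset(
    dataset_ids=['<dataset-uuid-1>', '<dataset-uuid-2>'],
    tenant_id='<tenant-uuid>',
    reranking_provider_name='cohere',          # example value
    reranking_model_name='rerank-english-v2.0',  # example value
    return_resource=False,
    retriever_from='workflow',
)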
"""Tool for querying multi dataset.""" + name: str = "dataset-" + args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput + description: str = "dataset multi retriever and rerank. " + tenant_id: str + dataset_ids: list[str] + top_k: int = 2 + score_threshold: Optional[float] = None + reranking_provider_name: str + reranking_model_name: str + return_resource: bool + retriever_from: str + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f'dataset-{tenant_id}', + tenant_id=tenant_id, + dataset_ids=dataset_ids, + **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'all_documents': all_documents, + 'hit_callbacks': self.hit_callbacks + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': document_score_list.get(segment.index_node_id, None) + } + + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if 
segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler]): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + for hit_callback in hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + if documents: + all_documents.extend(documents) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + + all_documents.extend(documents) \ No newline at end of file diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py new file mode 100644 index 00000000000000..13331d981bbecf --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py @@ -0,0 +1,159 @@ +from typing import Optional + +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.datasource.retrieval_service import RetrievalService +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetRetrieverToolInput(BaseModel): + query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") + + +class DatasetRetrieverTool(BaseTool): + """Tool for querying a Dataset.""" + name: str = "dataset" + args_schema: type[BaseModel] = DatasetRetrieverToolInput + description: str = "use this to retrieve a dataset. 
" + + tenant_id: str + dataset_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + @classmethod + def from_dataset(cls, dataset: Dataset, **kwargs): + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + return cls( + name=f'dataset-{dataset.id}', + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + description=description, + **kwargs + ) + + def _run(self, query: str) -> str: + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == self.dataset_id + ).first() + + if not dataset: + return '' + + for hit_callback in self.hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': 
document_score_list.get(segment.index_node_id, None) + + } + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py new file mode 100644 index 00000000000000..905ee1f80da0e3 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -0,0 +1,52 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + provider: str + mode: str + + +class MultipleRetrievalConfig(BaseModel): + """ + Multiple Retrieval Config. + """ + top_k: int + score_threshold: Optional[float] + reranking_model: RerankingModelConfig + + +class ModelConfig(BaseModel): + """ + Model Config. + """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class SingleRetrievalConfig(BaseModel): + """ + Single Retrieval Config. + """ + model: ModelConfig + + +class KnowledgeRetrievalNodeData(BaseNodeData): + """ + Knowledge retrieval Node Data. 
+ """ + variables: list[VariableSelector] + dataset_ids: list[str] + retrieval_mode: Literal['single', 'multiple'] + multiple_retrieval_config: MultipleRetrievalConfig + singleRetrievalConfig: SingleRetrievalConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py new file mode 100644 index 00000000000000..a501113dc313b8 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -0,0 +1,373 @@ +import threading +from typing import cast, Any + +from flask import current_app, Flask + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, Document +from models.workflow import WorkflowNodeExecutionStatus + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class KnowledgeRetrievalNode(BaseNode): + _node_data_cls = KnowledgeRetrievalNodeData + _node_type = NodeType.KNOWLEDGE_RETRIEVAL + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # retrieve knowledge + try: + outputs = self._fetch_dataset_retriever( + node_data=node_data, variables=variables + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs + ) + + except Exception as e: + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[ + dict[str, Any]]: + """ + A dataset tool is a tool that can be used to retrieve information from a dataset + :param node_data: node data + :param variables: variables + """ + tools = [] + 
+ available_datasets = []
+ dataset_ids = node_data.dataset_ids
+ for dataset_id in dataset_ids:
+ # get dataset from dataset id
+ dataset = db.session.query(Dataset).filter(
+ Dataset.tenant_id == self.tenant_id,
+ Dataset.id == dataset_id
+ ).first()
+
+ # pass if dataset is not available
+ if not dataset:
+ continue
+
+ # pass if dataset has no indexed content
+ if (dataset.available_document_count == 0
+ and dataset.available_segment_count == 0):
+ continue
+
+ available_datasets.append(dataset)
+ all_documents = []
+ if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
+ all_documents = self._single_retrieve(available_datasets, node_data, variables)
+ elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
+ all_documents = self._multiple_retrieve(available_datasets, node_data, variables)
+
+ document_score_list = {}
+ for item in all_documents:
+ if 'score' in item.metadata and item.metadata['score']:
+ document_score_list[item.metadata['doc_id']] = item.metadata['score']
+
+ document_context_list = []
+ index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+ segments = DocumentSegment.query.filter(
+ DocumentSegment.dataset_id.in_(dataset_ids),
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.status == 'completed',
+ DocumentSegment.enabled == True,
+ DocumentSegment.index_node_id.in_(index_node_ids)
+ ).all()
+ context_list = []
+ if segments:
+ index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+ sorted_segments = sorted(segments,
+ key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+ float('inf')))
+ for segment in sorted_segments:
+ if segment.answer:
+ document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+ else:
+ document_context_list.append(segment.content)
+
+ for segment in sorted_segments:
+ dataset = Dataset.query.filter_by(
+ id=segment.dataset_id
+ ).first()
+ document = Document.query.filter(Document.id == segment.document_id,
+ Document.enabled == True,
+ Document.archived == False,
+ ).first()
+ if dataset and document:
+ source = {
+ 'metadata': {
+ '_source': 'knowledge',
+ 'dataset_id': dataset.id,
+ 'dataset_name': dataset.name,
+ 'document_id': document.id,
+ 'document_name': document.name,
+ 'document_data_source_type': document.data_source_type,
+ 'segment_id': segment.id,
+ 'retriever_from': 'workflow',
+ 'score': document_score_list.get(segment.index_node_id, None),
+ 'segment_hit_count': segment.hit_count,
+ 'segment_word_count': segment.word_count,
+ 'segment_position': segment.position
+ }
+ }
+ if segment.answer:
+ source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+ else:
+ source['content'] = segment.content
+ context_list.append(source)
+
+ return context_list
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ node_data = cast(cls._node_data_cls, node_data)
+ return {
+ variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+ }
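
A note on the re-ordering above: the segment query returns rows in arbitrary database order, so the code rebuilds the retrieval ranking by mapping each index_node_id to its position in the ranked result list, sending unknown ids to the end. A standalone sketch of the trick with made-up data:

    # Standalone sketch of the rank-preserving re-sort used above.
    index_node_ids = ['node-3', 'node-1', 'node-2']          # ranked by score
    segments = [{'index_node_id': 'node-1', 'text': 'a'},    # arbitrary DB order
                {'index_node_id': 'node-3', 'text': 'b'},
                {'index_node_id': 'node-9', 'text': 'c'}]    # not in the ranking

    position = {node_id: i for i, node_id in enumerate(index_node_ids)}
    ordered = sorted(segments, key=lambda s: position.get(s['index_node_id'], float('inf')))

    print([s['index_node_id'] for s in ordered])  # ['node-3', 'node-1', 'node-9']
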
+
+ def _single_retrieve(self, available_datasets, node_data, variables):
+ tools = []
+ for dataset in available_datasets:
+ description = dataset.description
+ if not description:
+ description = 'useful for when you want to answer queries about the ' + dataset.name
+
+ description = description.replace('\n', '').replace('\r', '')
+ message_tool = PromptMessageTool(
+ name=dataset.id,
+ description=description,
+ parameters={
+ "type": "object",
+ "properties": {},
+ "required": [],
+ }
+ )
+ tools.append(message_tool)
+ # fetch model config
+ model_instance, model_config = self._fetch_model_config(node_data)
+ prompt_messages = [
+ SystemPromptMessage(content='You are a helpful AI assistant.'),
+ UserPromptMessage(content=variables['#query#'])
+ ]
+ result = model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ tools=tools,
+ stream=False,
+ model_parameters={
+ 'temperature': 0.2,
+ 'top_p': 0.3,
+ 'max_tokens': 1500
+ }
+ )
+
+ if result.message.tool_calls:
+ # get retrieval model config
+ function_call_name = result.message.tool_calls[0].function.name
+ dataset = db.session.query(Dataset).filter(
+ Dataset.id == function_call_name
+ ).first()
+ if dataset:
+ retrieval_model_config = dataset.retrieval_model \
+ if dataset.retrieval_model else default_retrieval_model
+
+ # get top k
+ top_k = retrieval_model_config['top_k']
+ # get retrieval method
+ retrival_method = retrieval_model_config['search_method']
+ # get reranking model
+ reranking_model = retrieval_model_config['reranking_model']
+ # get score threshold
+ score_threshold = 0.0
+ score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+ if score_threshold_enabled:
+ score_threshold = retrieval_model_config.get("score_threshold")
+
+ results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
+ query=variables['#query#'],
+ top_k=top_k, score_threshold=score_threshold,
+ reranking_model=reranking_model)
+ return results
+
+ # fall back to no context if the model made no tool call
+ # or referenced an unknown dataset
+ return []
+
+ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
+ ModelInstance, ModelConfigWithCredentialsEntity]:
+ """
+ Fetch model config
+ :param node_data: node data
+ :return:
+ """
+ model_name = node_data.single_retrieval_config.model.name
+ provider_name = node_data.single_retrieval_config.model.provider
+
+ model_manager = ModelManager()
+ model_instance = model_manager.get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=provider_name,
+ model=model_name
+ )
+
+ provider_model_bundle = model_instance.provider_model_bundle
+ model_type_instance = model_instance.model_type_instance
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+ model_credentials = model_instance.credentials
+
+ # check model
+ provider_model = provider_model_bundle.configuration.get_provider_model(
+ model=model_name,
+ model_type=ModelType.LLM
+ )
+
+ if provider_model is None:
+ raise ValueError(f"Model {model_name} does not exist.")
+
+ if provider_model.status == ModelStatus.NO_CONFIGURE:
+ raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
+ elif provider_model.status == ModelStatus.NO_PERMISSION:
+ raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+ elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+ raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+ # model config
+ completion_params = node_data.single_retrieval_config.model.completion_params
+ stop = []
+ if 'stop' in completion_params:
+ stop = completion_params['stop']
+ del completion_params['stop']
+
+ # get model mode
+ model_mode = node_data.single_retrieval_config.model.mode
+ if not model_mode:
+ raise ValueError("LLM mode is required.")
+
+ model_schema = model_type_instance.get_model_schema(
+ model_name,
+ model_credentials
+ )
+
+ if not model_schema:
+ raise ValueError(f"Model {model_name} does not exist.")
+
+ return model_instance, ModelConfigWithCredentialsEntity(
+ provider=provider_name,
+ model=model_name,
+ model_schema=model_schema,
+ mode=model_mode,
+ provider_model_bundle=provider_model_bundle,
+ credentials=model_credentials,
+ parameters=completion_params,
+ stop=stop,
+ )
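
The routing idea in _single_retrieve is worth calling out: every available dataset is exposed to the model as a no-argument tool named after the dataset id, and the model's first tool call decides which dataset to query. A minimal sketch of the pattern, with a stand-in chooser in place of the real LLM call:

    # Sketch only: the chooser lambda stands in for model_instance.invoke_llm.
    from dataclasses import dataclass

    @dataclass
    class FakeDataset:
        id: str
        description: str

    def route_query(query: str, datasets: list, choose) -> str:
        tools = [{'name': d.id, 'description': d.description, 'parameters': {}}
                 for d in datasets]
        return choose(query, tools)  # an LLM tool call in the real node

    datasets = [FakeDataset('ds-legal', 'contracts and policies'),
                FakeDataset('ds-support', 'product troubleshooting guides')]
    picked = route_query('How do I reset my router?', datasets,
                         choose=lambda q, tools: tools[1]['name'])
    assert picked == 'ds-support'
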
ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _multiple_retrieve(self, available_datasets, node_data, variables): + threads = [] + all_documents = [] + dataset_ids = [dataset.id for dataset in available_datasets] + for dataset in available_datasets: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset.id, + 'query': variables['#query#'], + 'top_k': node_data.multiple_retrieval_config.top_k, + 'all_documents': all_documents, + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=node_data.multiple_retrieval_config.reranking_model.provider, + model_type=ModelType.RERANK, + model=node_data.multiple_retrieval_config.reranking_model.name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(variables['#query#'], all_documents, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.top_k) + + return all_documents + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + + all_documents.extend(documents) diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py new file mode 100644 index 00000000000000..67163c93cd2b19 --- /dev/null +++ b/api/core/workflow/nodes/llm/entities.py @@ -0,0 +1,51 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class ModelConfig(BaseModel): + """ + Model Config. 
+ """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class ContextConfig(BaseModel): + """ + Context Config. + """ + enabled: bool + variable_selector: Optional[list[str]] = None + + +class VisionConfig(BaseModel): + """ + Vision Config. + """ + class Configs(BaseModel): + """ + Configs. + """ + detail: Literal['low', 'high'] + + enabled: bool + configs: Optional[Configs] = None + + +class LLMNodeData(BaseNodeData): + """ + LLM Node Data. + """ + model: ModelConfig + variables: list[VariableSelector] = [] + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + memory: Optional[MemoryConfig] = None + context: ContextConfig + vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py new file mode 100644 index 00000000000000..cb5a33309141dd --- /dev/null +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -0,0 +1,485 @@ +from collections.abc import Generator +from typing import Optional, cast + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.llm.entities import LLMNodeData +from extensions.ext_database import db +from models.model import Conversation +from models.provider import Provider, ProviderType +from models.workflow import WorkflowNodeExecutionStatus + + +class LLMNode(BaseNode): + _node_data_cls = LLMNodeData + node_type = NodeType.LLM + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + node_inputs = None + process_data = None + + try: + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data, variable_pool) + + node_inputs = { + **inputs + } + + # fetch files + files: list[FileObj] = self._fetch_files(node_data, variable_pool) + + if files: + node_inputs['#files#'] = [{ + 'type': file.type.value, + 'transfer_method': file.transfer_method.value, + 'url': file.url, + 'upload_file_id': file.upload_file_id, + } for file in files] + + # fetch context value + context = self._fetch_context(node_data, variable_pool) + + if context: + node_inputs['#context#'] = context + + # fetch model config + model_instance, model_config = 
+
+ # fetch memory
+ memory = self._fetch_memory(node_data, variable_pool, model_instance)
+
+ # fetch prompt messages
+ prompt_messages, stop = self._fetch_prompt_messages(
+ node_data=node_data,
+ inputs=inputs,
+ files=files,
+ context=context,
+ memory=memory,
+ model_config=model_config
+ )
+
+ process_data = {
+ 'model_mode': model_config.mode,
+ 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+ model_mode=model_config.mode,
+ prompt_messages=prompt_messages
+ )
+ }
+
+ # invoke model
+ result_text, usage = self._invoke_llm(
+ node_data=node_data,
+ model_instance=model_instance,
+ prompt_messages=prompt_messages,
+ stop=stop
+ )
+ except Exception as e:
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=str(e),
+ inputs=node_inputs,
+ process_data=process_data
+ )
+
+ outputs = {
+ 'text': result_text,
+ 'usage': jsonable_encoder(usage)
+ }
+
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs=node_inputs,
+ process_data=process_data,
+ outputs=outputs,
+ metadata={
+ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+ NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+ NodeRunMetadataKey.CURRENCY: usage.currency
+ }
+ )
+
+ def _invoke_llm(self, node_data: LLMNodeData,
+ model_instance: ModelInstance,
+ prompt_messages: list[PromptMessage],
+ stop: list[str]) -> tuple[str, LLMUsage]:
+ """
+ Invoke large language model
+ :param node_data: node data
+ :param model_instance: model instance
+ :param prompt_messages: prompt messages
+ :param stop: stop words
+ :return:
+ """
+ # release the database connection; the streaming invocation below can be long-running
+ db.session.close()
+
+ invoke_result = model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters=node_data.model.completion_params,
+ stop=stop,
+ stream=True,
+ user=self.user_id,
+ )
+
+ # handle invoke result
+ text, usage = self._handle_invoke_result(
+ invoke_result=invoke_result
+ )
+
+ # deduct quota
+ self._deduct_llm_quota(model_instance=model_instance, usage=usage)
+
+ return text, usage
+
+ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+ """
+ Handle invoke result
+ :param invoke_result: invoke result
+ :return:
+ """
+ full_text = ''
+ usage = None
+ for result in invoke_result:
+ text = result.delta.message.content
+ full_text += text
+
+ self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+
+ if not usage and result.delta.usage:
+ usage = result.delta.usage
+
+ if not usage:
+ usage = LLMUsage.empty_usage()
+
+ return full_text, usage
+
+ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+ """
+ Fetch inputs
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ inputs = {}
+ for variable_selector in node_data.variables:
+ variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+ if variable_value is None:
+ raise ValueError(f'Variable {variable_selector.value_selector} not found')
+
+ inputs[variable_selector.variable] = variable_value
+
+ return inputs
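
_handle_invoke_result above drains the streaming generator, concatenating text deltas and keeping the first usage report it sees. A standalone sketch, with namedtuple chunks standing in for the runtime's chunk objects (the None guard is a nicety of the sketch, not something the node does):

    # Standalone sketch of stream accumulation.
    from collections import namedtuple

    Chunk = namedtuple('Chunk', ['text', 'usage'])

    def accumulate(stream):
        full_text, usage = '', None
        for chunk in stream:
            full_text += chunk.text or ''   # guard against None deltas
            if usage is None and chunk.usage:
                usage = chunk.usage
        return full_text, usage or {'total_tokens': 0}

    text, usage = accumulate(iter([Chunk('Hel', None), Chunk('lo', {'total_tokens': 5})]))
    assert text == 'Hello' and usage['total_tokens'] == 5
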
+
+ def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]:
+ """
+ Fetch files
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ if not node_data.vision.enabled:
+ return []
+
+ files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
+ if not files:
+ return []
+
+ return files
+
+ def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+ """
+ Fetch context
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ if not node_data.context.enabled:
+ return None
+
+ context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
+ if context_value:
+ if isinstance(context_value, str):
+ return context_value
+ elif isinstance(context_value, list):
+ context_str = ''
+ for item in context_value:
+ if 'content' not in item:
+ raise ValueError(f'Invalid context structure: {item}')
+
+ context_str += item['content'] + '\n'
+
+ return context_str.strip()
+
+ return None
+
+ def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+ """
+ Fetch model config
+ :param node_data: node data
+ :return:
+ """
+ model_name = node_data.model.name
+ provider_name = node_data.model.provider
+
+ model_manager = ModelManager()
+ model_instance = model_manager.get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=provider_name,
+ model=model_name
+ )
+
+ provider_model_bundle = model_instance.provider_model_bundle
+ model_type_instance = model_instance.model_type_instance
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+ model_credentials = model_instance.credentials
+
+ # check model
+ provider_model = provider_model_bundle.configuration.get_provider_model(
+ model=model_name,
+ model_type=ModelType.LLM
+ )
+
+ if provider_model is None:
+ raise ValueError(f"Model {model_name} does not exist.")
+
+ if provider_model.status == ModelStatus.NO_CONFIGURE:
+ raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
+ elif provider_model.status == ModelStatus.NO_PERMISSION:
+ raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+ elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+ raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+ # model config
+ completion_params = node_data.model.completion_params
+ stop = []
+ if 'stop' in completion_params:
+ stop = completion_params['stop']
+ del completion_params['stop']
+
+ # get model mode
+ model_mode = node_data.model.mode
+ if not model_mode:
+ raise ValueError("LLM mode is required.")
+
+ model_schema = model_type_instance.get_model_schema(
+ model_name,
+ model_credentials
+ )
+
+ if not model_schema:
+ raise ValueError(f"Model {model_name} does not exist.")
+
+ return model_instance, ModelConfigWithCredentialsEntity(
+ provider=provider_name,
+ model=model_name,
+ model_schema=model_schema,
+ mode=model_mode,
+ provider_model_bundle=provider_model_bundle,
+ credentials=model_credentials,
+ parameters=completion_params,
+ stop=stop,
+ )
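
_fetch_context above accepts either a plain string or a list of retrieval records carrying a 'content' key (the shape produced by the knowledge retrieval node). A standalone sketch of the same contract:

    # Standalone sketch of the context-formatting contract.
    def format_context(value):
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            parts = []
            for item in value:
                if 'content' not in item:
                    raise ValueError(f'Invalid context structure: {item}')
                parts.append(item['content'])
            return '\n'.join(parts).strip()
        return None

    assert format_context('plain text') == 'plain text'
    assert format_context([{'content': 'a'}, {'content': 'b'}]) == 'a\nb'
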
+
+ def _fetch_memory(self, node_data: LLMNodeData,
+ variable_pool: VariablePool,
+ model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
+ """
+ Fetch memory
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ if not node_data.memory:
+ return None
+
+ # get conversation id
+ conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
+ if conversation_id is None:
+ return None
+
+ # get conversation
+ conversation = db.session.query(Conversation).filter(
+ Conversation.tenant_id == self.tenant_id,
+ Conversation.app_id == self.app_id,
+ Conversation.id == conversation_id
+ ).first()
+
+ if not conversation:
+ return None
+
+ memory = TokenBufferMemory(
+ conversation=conversation,
+ model_instance=model_instance
+ )
+
+ return memory
+
+ def _fetch_prompt_messages(self, node_data: LLMNodeData,
+ inputs: dict[str, str],
+ files: list[FileObj],
+ context: Optional[str],
+ memory: Optional[TokenBufferMemory],
+ model_config: ModelConfigWithCredentialsEntity) \
+ -> tuple[list[PromptMessage], Optional[list[str]]]:
+ """
+ Fetch prompt messages
+ :param node_data: node data
+ :param inputs: inputs
+ :param files: files
+ :param context: context
+ :param memory: memory
+ :param model_config: model config
+ :return:
+ """
+ prompt_transform = AdvancedPromptTransform()
+ prompt_messages = prompt_transform.get_prompt(
+ prompt_template=node_data.prompt_template,
+ inputs=inputs,
+ query='',
+ files=files,
+ context=context,
+ memory_config=node_data.memory,
+ memory=memory,
+ model_config=model_config
+ )
+ stop = model_config.stop
+
+ return prompt_messages, stop
+
+ def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None:
+ """
+ Deduct LLM quota
+ :param model_instance: model instance
+ :param usage: usage
+ :return:
+ """
+ provider_model_bundle = model_instance.provider_model_bundle
+ provider_configuration = provider_model_bundle.configuration
+
+ if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+ return
+
+ system_configuration = provider_configuration.system_configuration
+
+ quota_unit = None
+ for quota_configuration in system_configuration.quota_configurations:
+ if quota_configuration.quota_type == system_configuration.current_quota_type:
+ quota_unit = quota_configuration.quota_unit
+
+ if quota_configuration.quota_limit == -1:
+ return
+
+ break
+
+ used_quota = None
+ if quota_unit:
+ if quota_unit == QuotaUnit.TOKENS:
+ used_quota = usage.total_tokens
+ elif quota_unit == QuotaUnit.CREDITS:
+ # flat per-call pricing: gpt-4 class models cost 20 credits, others 1
+ used_quota = 20 if 'gpt-4' in model_instance.model else 1
+
+ if used_quota is not None:
+ db.session.query(Provider).filter(
+ Provider.tenant_id == self.tenant_id,
+ Provider.provider_name == model_instance.provider,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == system_configuration.current_quota_type.value,
+ Provider.quota_limit > Provider.quota_used
+ ).update({'quota_used': Provider.quota_used + used_quota})
+ db.session.commit()
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param node_data: node data
+ :return:
+ """
+ node_data = cast(cls._node_data_cls, node_data)
+
+ variable_mapping = {}
+ for variable_selector in node_data.variables:
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
+ if node_data.context.enabled:
+ variable_mapping['#context#'] = node_data.context.variable_selector
+
+ if node_data.vision.enabled:
+ variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
+
+ return variable_mapping
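
The quota arithmetic in _deduct_llm_quota, extracted as a sketch: token-metered plans are charged usage.total_tokens, credit-metered plans a flat 1 credit, or 20 for gpt-4-class models. The string values below stand in for the QuotaUnit enum:

    # Standalone sketch of the quota deduction rule.
    def used_quota(quota_unit: str, total_tokens: int, model: str) -> int:
        if quota_unit == 'tokens':
            return total_tokens
        if quota_unit == 'credits':
            return 20 if 'gpt-4' in model else 1
        return 0

    assert used_quota('tokens', 1234, 'gpt-3.5-turbo') == 1234
    assert used_quota('credits', 1234, 'gpt-4-turbo') == 20
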
+
+ @classmethod
+ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+ """
+ Get default config of node.
+ :param filters: filter by node config parameters.
+ :return:
+ """
+ return {
+ "type": "llm",
+ "config": {
+ "prompt_templates": {
+ "chat_model": {
+ "prompts": [
+ {
+ "role": "system",
+ "text": "You are a helpful AI assistant."
+ }
+ ]
+ },
+ "completion_model": {
+ "conversation_histories_role": {
+ "user_prefix": "Human",
+ "assistant_prefix": "Assistant"
+ },
+ "prompt": {
+ "text": "Here is the chat histories between human and assistant, inside "
+ "<histories></histories> XML tags.\n\n<histories>\n{{"
+ "#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
+ },
+ "stop": ["Human:"]
+ }
+ }
+ }
+ }
diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
new file mode 100644
index 00000000000000..f676b6372ac3ec
--- /dev/null
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -0,0 +1,19 @@
+from typing import Optional
+
+from core.workflow.nodes.base_node import BaseNode
+
+
+class QuestionClassifierNode(BaseNode):
+ @classmethod
+ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+ """
+ Get default config of node.
+ :param filters: filter by node config parameters.
+ :return:
+ """
+ return {
+ "type": "question-classifier",
+ "config": {
+ "instructions": "" # TODO
+ }
+ }
diff --git a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py
new file mode 100644
index 00000000000000..0bd5f203bf72a5
--- /dev/null
+++ b/api/core/workflow/nodes/start/entities.py
@@ -0,0 +1,9 @@
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class StartNodeData(BaseNodeData):
+ """
+ Start Node Data.
+ """
+ variables: list[VariableEntity] = []
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
new file mode 100644
index 00000000000000..08171457fbb4dc
--- /dev/null
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -0,0 +1,78 @@
+from typing import cast
+
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.start.entities import StartNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class StartNode(BaseNode):
+ _node_data_cls = StartNodeData
+ _node_type = NodeType.START
+
+ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+ """
+ Run node
+ :param variable_pool: variable pool
+ :return:
+ """
+ node_data = cast(self._node_data_cls, self.node_data)
+ variables = node_data.variables
+
+ # Get cleaned inputs
+ cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
+
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs=cleaned_inputs,
+ outputs=cleaned_inputs
+ )
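
StartNode deliberately echoes its cleaned inputs as outputs: that is what makes them addressable downstream through value selectors such as ['start', 'query']. The _get_cleaned_inputs helper that follows enforces the form rules. A sketch of the addressing, with a plain dict standing in for VariablePool:

    # Standalone sketch: start-node outputs become pool entries keyed by node id.
    pool = {}

    def append_variable(node_id: str, key: str, value) -> None:
        pool[(node_id, key)] = value

    start_outputs = {'query': 'How do I reset my router?'}  # cleaned user inputs
    for key, value in start_outputs.items():
        append_variable('start', key, value)

    assert pool[('start', 'query')] == 'How do I reset my router?'
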
+
+ def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
+ if user_inputs is None:
+ user_inputs = {}
+
+ filtered_inputs = {}
+
+ for variable_config in variables:
+ variable = variable_config.variable
+
+ if variable not in user_inputs or not user_inputs[variable]:
+ if variable_config.required:
+ raise ValueError(f"Input form variable {variable} is required")
+ else:
+ filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
+ continue
+
+ value = user_inputs[variable]
+
+ if value:
+ if not isinstance(value, str):
+ raise ValueError(f"{variable} in input form must be a string")
+
+ if variable_config.type == VariableEntity.Type.SELECT:
+ options = variable_config.options if variable_config.options is not None else []
+ if value not in options:
+ raise ValueError(f"{variable} in input form must be one of the following: {options}")
+ else:
+ if variable_config.max_length is not None:
+ max_length = variable_config.max_length
+ if len(value) > max_length:
+ raise ValueError(f'{variable} in input form must be less than {max_length} characters')
+
+ filtered_inputs[variable] = value.replace('\x00', '') if value else None
+
+ return filtered_inputs
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param node_data: node data
+ :return:
+ """
+ return {}
diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py
new file mode 100644
index 00000000000000..d9099a8118498e
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/entities.py
@@ -0,0 +1,12 @@
+
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class TemplateTransformNodeData(BaseNodeData):
+ """
+ Template Transform Node Data.
+ """
+ variables: list[VariableSelector]
+ template: str
\ No newline at end of file
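
The node defined next delegates template rendering to the sandboxed CodeExecutor with language='jinja2' (the CODE_EXECUTION_ENDPOINT service), never rendering in-process. A local sketch of the equivalent transformation using the jinja2 package directly:

    # Sketch only: the real node sends the template to the sandbox instead.
    from jinja2 import Template

    variables = {'arg1': 'hello', 'arg2': ['a', 'b']}
    rendered = Template('{{ arg1 }} / {{ arg2 | join(", ") }}').render(**variables)
    assert rendered == 'hello / a, b'
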
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
new file mode 100644
index 00000000000000..15d4b2a6e7b81f
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -0,0 +1,90 @@
+from typing import Optional, cast
+
+from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = 1000
+
+
+class TemplateTransformNode(BaseNode):
+ _node_data_cls = TemplateTransformNodeData
+ _node_type = NodeType.TEMPLATE_TRANSFORM
+
+ @classmethod
+ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+ """
+ Get default config of node.
+ :param filters: filter by node config parameters.
+ :return:
+ """
+ return {
+ "type": "template-transform",
+ "config": {
+ "variables": [
+ {
+ "variable": "arg1",
+ "value_selector": []
+ }
+ ],
+ "template": "{{ arg1 }}"
+ }
+ }
+
+ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+ """
+ Run node
+ """
+ node_data: TemplateTransformNodeData = cast(self._node_data_cls, self.node_data)
+
+ # Get variables
+ variables = {}
+ for variable_selector in node_data.variables:
+ variable = variable_selector.variable
+ value = variable_pool.get_variable_value(
+ variable_selector=variable_selector.value_selector
+ )
+
+ variables[variable] = value
+
+ # Run code
+ try:
+ result = CodeExecutor.execute_code(
+ language='jinja2',
+ code=node_data.template,
+ inputs=variables
+ )
+ except CodeExecutionException as e:
+ return NodeRunResult(
+ inputs=variables,
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=str(e)
+ )
+
+ if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
+ return NodeRunResult(
+ inputs=variables,
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
+ )
+
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs=variables,
+ outputs={
+ 'output': result['result']
+ }
+ )
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param node_data: node data
+ :return:
+ """
+ return {
+ variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+ }
\ No newline at end of file
diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/core/workflow/nodes/tool/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
new file mode 100644
index 00000000000000..7eb3cf655b1bdc
--- /dev/null
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -0,0 +1,39 @@
+from typing import Literal, Optional, Union
+
+from pydantic import BaseModel, validator
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+ToolParameterValue = Union[str, int, float, bool]
+
+
+class ToolEntity(BaseModel):
+ provider_id: str
+ provider_type: Literal['builtin', 'api']
+ provider_name: str # redundancy
+ tool_name: str
+ tool_label: str # redundancy
+ tool_configurations: dict[str, ToolParameterValue]
+
+
+class ToolNodeData(BaseNodeData, ToolEntity):
+ """
+ Tool Node Schema
+ """
+ class ToolInput(BaseModel):
+ variable: str
+ variable_type: Literal['selector', 'static']
+ value_selector: Optional[list[str]]
+ value: Optional[str]
+
+ @validator('value')
+ def check_value(cls, value, values, **kwargs):
+ if values['variable_type'] == 'static' and value is None:
+ raise ValueError('value is required for static variable')
+ return value
+
+ @validator('value_selector')
+ def check_value_selector(cls, value_selector, values, **kwargs):
+ if values['variable_type'] == 'selector' and value_selector is None:
+ raise ValueError('value_selector is required for selector variable')
+ return value_selector
+
+ tool_parameters: list[ToolInput]
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
new file mode 100644
index 00000000000000..d0bfd9e7973467
--- /dev/null
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -0,0 +1,147 @@
+from os import path
+from typing import cast
+ +from core.file.file_obj import FileTransferMethod +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.tool.entities import ToolNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class ToolNode(BaseNode): + """ + Tool Node + """ + _node_data_cls = ToolNodeData + _node_type = NodeType.TOOL + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run the tool node + """ + + node_data = cast(ToolNodeData, self.node_data) + + # get parameters + parameters = self._generate_parameters(variable_pool, node_data) + # get tool runtime + try: + tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters, + error=f'Failed to get tool runtime: {str(e)}' + ) + + try: + messages = tool_runtime.invoke(self.user_id, parameters) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters, + error=f'Failed to invoke tool: {str(e)}', + ) + + # convert tool messages + plain_text, files = self._convert_tool_messages(messages) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'text': plain_text, + 'files': files + }, + inputs=parameters + ) + + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: + """ + Generate parameters + """ + return { + k.variable: + k.value if k.variable_type == 'static' else + variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else '' + for k in node_data.tool_parameters + } + + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + # extract plain text and files + files = self._extract_tool_response_binary(messages) + plain_text = self._extract_tool_response_text(messages) + + return plain_text, files + + def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: + """ + Extract tool response binary + """ + result = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + url = response.message + ext = path.splitext(url)[1] + mimetype = response.meta.get('mime_type', 'image/jpeg') + filename = response.save_as or url.split('/')[-1] + result.append({ + 'type': 'image', + 'transfer_method': FileTransferMethod.TOOL_FILE, + 'url': url, + 'upload_file_id': None, + 'filename': filename, + 'file-ext': ext, + 'mime-type': mimetype, + }) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append({ + 'type': 'image', # TODO: only support image for now + 'transfer_method': FileTransferMethod.TOOL_FILE, + 'url': response.message, + 'upload_file_id': None, + 'filename': response.save_as, + 'file-ext': 
path.splitext(response.save_as)[1],
+ 'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
+ })
+ elif response.type == ToolInvokeMessage.MessageType.LINK:
+ pass # TODO:
+
+ return result
+
+ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
+ """
+ Extract tool response text
+ """
+ return ''.join([
+ f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
+ f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
+ for message in tool_response
+ ])
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
+ """
+ Extract variable selector to variable mapping
+ """
+ return {
+ k.variable: k.value_selector
+ for k in node_data.tool_parameters
+ if k.variable_type == 'selector'
+ }
diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py
new file mode 100644
index 00000000000000..231a26a6613baf
--- /dev/null
+++ b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py
@@ -0,0 +1,5 @@
+from core.workflow.nodes.base_node import BaseNode
+
+
+class VariableAssignerNode(BaseNode):
+ pass
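
The engine manager defined next dispatches on a NodeType-to-class registry. A minimal sketch of the pattern with stand-in node classes; note the None guard, which the diff's single_step_run_workflow_node omits before instantiating node_cls:

    # Standalone sketch of registry-based node dispatch.
    from enum import Enum

    class NodeType(Enum):
        START = 'start'
        LLM = 'llm'

    class StartNode:
        def run(self):
            return 'start ran'

    class LLMNode:
        def run(self):
            return 'llm ran'

    NODE_CLASSES = {NodeType.START: StartNode, NodeType.LLM: LLMNode}

    def instantiate(node_config: dict):
        node_type = NodeType(node_config.get('data', {}).get('type'))
        node_cls = NODE_CLASSES.get(node_type)
        if node_cls is None:
            raise ValueError(f'unsupported node type: {node_type}')
        return node_cls()

    assert instantiate({'data': {'type': 'llm'}}).run() == 'llm ran'
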
diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py
new file mode 100644
index 00000000000000..a7379e6e99fceb
--- /dev/null
+++ b/api/core/workflow/workflow_engine_manager.py
@@ -0,0 +1,490 @@
+import logging
+import time
+from typing import Optional
+
+from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
+from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
+from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
+from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.base_node import BaseNode, UserFrom
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.end.end_node import EndNode
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
+from extensions.ext_database import db
+from models.workflow import (
+ Workflow,
+ WorkflowNodeExecutionStatus,
+)
+
+node_classes = {
+ NodeType.START: StartNode,
+ NodeType.END: EndNode,
+ NodeType.ANSWER: AnswerNode,
+ NodeType.LLM: LLMNode,
+ NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
+ NodeType.IF_ELSE: IfElseNode,
+ NodeType.CODE: CodeNode,
+ NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
+ NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
+ NodeType.HTTP_REQUEST: HttpRequestNode,
+ NodeType.TOOL: ToolNode,
+ NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
+}
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowEngineManager:
+ def get_default_configs(self) -> list[dict]:
+ """
+ Get default block configs
+ """
+ default_block_configs = []
+ for node_type, node_class in node_classes.items():
+ default_config = node_class.get_default_config()
+ if default_config:
+ default_block_configs.append(default_config)
+
+ return default_block_configs
+
+ def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]:
+ """
+ Get default config of node.
+ :param node_type: node type
+ :param filters: filter by node config parameters.
+ :return:
+ """
+ node_class = node_classes.get(node_type)
+ if not node_class:
+ return None
+
+ default_config = node_class.get_default_config(filters=filters)
+ if not default_config:
+ return None
+
+ return default_config
+
+ def run_workflow(self, workflow: Workflow,
+ user_id: str,
+ user_from: UserFrom,
+ user_inputs: dict,
+ system_inputs: Optional[dict] = None,
+ callbacks: list[BaseWorkflowCallback] = None) -> None:
+ """
+ Run workflow
+ :param workflow: Workflow instance
+ :param user_id: user id
+ :param user_from: user from
+ :param user_inputs: user variables inputs
+ :param system_inputs: system inputs, like: query, files
+ :param callbacks: workflow callbacks
+ :return:
+ """
+ # fetch workflow graph
+ graph = workflow.graph_dict
+ if not graph:
+ raise ValueError('workflow graph not found')
+
+ if 'nodes' not in graph or 'edges' not in graph:
+ raise ValueError('nodes or edges not found in workflow graph')
+
+ if not isinstance(graph.get('nodes'), list):
+ raise ValueError('nodes in workflow graph must be a list')
+
+ if not isinstance(graph.get('edges'), list):
+ raise ValueError('edges in workflow graph must be a list')
+
+ # init workflow run
+ if callbacks:
+ for callback in callbacks:
+ callback.on_workflow_run_started()
+
+ # init workflow run state
+ workflow_run_state = WorkflowRunState(
+ workflow=workflow,
+ start_at=time.perf_counter(),
+ variable_pool=VariablePool(
+ system_variables=system_inputs,
+ user_inputs=user_inputs
+ ),
+ user_id=user_id,
+ user_from=user_from
+ )
+
+ try:
+ predecessor_node = None
+ has_entry_node = False
+ while True:
+ # get next node, multiple target nodes in the future
+ next_node = self._get_next_node(
+ workflow_run_state=workflow_run_state,
+ graph=graph,
+ predecessor_node=predecessor_node,
+ callbacks=callbacks
+ )
+
+ if not next_node:
+ break
+
+ has_entry_node = True
+
+ # max steps 30 reached
+ if len(workflow_run_state.workflow_nodes_and_results) > 30:
+ raise ValueError('Max steps 30 reached.')
+
+ # or max execution time 10min reached
+ if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
+ raise ValueError('Max execution time 10min reached.')
+
+ # run workflow, run multiple target nodes in the future
+ self._run_workflow_node(
+ workflow_run_state=workflow_run_state,
+ node=next_node,
+ predecessor_node=predecessor_node,
+ callbacks=callbacks
+ )
+
+ if next_node.node_type in [NodeType.END, NodeType.ANSWER]:
+ break
+
+ predecessor_node = next_node
+
+ if not has_entry_node:
+ self._workflow_run_failed(
+ error='Start node not found in workflow graph.',
+ callbacks=callbacks
+ )
+ return
+ except 
GenerateTaskStoppedException as e: + return + except Exception as e: + self._workflow_run_failed( + error=str(e), + callbacks=callbacks + ) + return + + # workflow run success + self._workflow_run_success( + callbacks=callbacks + ) + + def single_step_run_workflow_node(self, workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get('id') == node_id: + node_config = node + break + + if not node_config: + raise ValueError('node id not found in workflow graph') + + # Get node class + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + + # init workflow run state + node_instance = node_cls( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + config=node_config + ) + + try: + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={} + ) + + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + except NotImplementedError: + variable_mapping = {} + + for variable_key, variable_selector in variable_mapping.items(): + if variable_key not in user_inputs: + raise ValueError(f'Variable key {variable_key} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + + # append variable and value to variable pool + variable_pool.append_variable( + node_id=variable_node_id, + variable_key_list=variable_key_list, + value=user_inputs.get(variable_key) + ) + + # run node + node_run_result = node_instance.run( + variable_pool=variable_pool + ) + except Exception as e: + raise WorkflowNodeRunFailedError( + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_title=node_instance.node_data.title, + error=str(e) + ) + + return node_instance, node_run_result + + + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow run success + :param callbacks: workflow callbacks + :return: + """ + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_succeeded() + + def _workflow_run_failed(self, error: str, + callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow run failed + :param error: error message + :param callbacks: workflow callbacks + :return: + """ + if callbacks: + for callback in callbacks: + callback.on_workflow_run_failed( + error=error + ) + + def _get_next_node(self, workflow_run_state: WorkflowRunState, + graph: dict, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: + """ + Get next node + multiple target nodes in the future. 
+ :param graph: workflow graph + :param predecessor_node: predecessor node + :param callbacks: workflow callbacks + :return: + """ + nodes = graph.get('nodes') + if not nodes: + return None + + if not predecessor_node: + for node_config in nodes: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: + return StartNode( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + config=node_config, + callbacks=callbacks + ) + else: + edges = graph.get('edges') + source_node_id = predecessor_node.node_id + + # fetch all outgoing edges from source node + outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] + if not outgoing_edges: + return None + + # fetch target node id from outgoing edges + outgoing_edge = None + source_handle = predecessor_node.node_run_result.edge_source_handle + if source_handle: + for edge in outgoing_edges: + if edge.get('source_handle') and edge.get('source_handle') == source_handle: + outgoing_edge = edge + break + else: + outgoing_edge = outgoing_edges[0] + + if not outgoing_edge: + return None + + target_node_id = outgoing_edge.get('target') + + # fetch target node from target node id + target_node_config = None + for node in nodes: + if node.get('id') == target_node_id: + target_node_config = node + break + + if not target_node_config: + return None + + # get next node + target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) + + return target_node( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + config=target_node_config, + callbacks=callbacks + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + def _run_workflow_node(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> None: + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_started( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None + ) + + db.session.close() + + workflow_nodes_and_result = WorkflowNodeAndResult( + node=node, + result=None + ) + + # add to workflow_nodes_and_results + workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) + + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool + ) + except Exception as e: + logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) + + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: + # node run failed + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_failed( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + 
error=node_run_result.error, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, + ) + + raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") + + workflow_nodes_and_result.result = node_run_result + + # node run success + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_succeeded( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, + outputs=node_run_result.outputs, + execution_metadata=node_run_result.metadata + ) + + if node_run_result.outputs: + for variable_key, variable_value in node_run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + + if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) + + db.session.close() + + + def _append_variables_recursively(self, variable_pool: VariablePool, + node_id: str, + variable_key_list: list[str], + variable_value: VariableValue): + """ + Append variables recursively + :param variable_pool: variable pool + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + variable_pool.append_variable( + node_id=node_id, + variable_key_list=variable_key_list, + value=variable_value + ) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + variable_pool=variable_pool, + node_id=node_id, + variable_key_list=new_key_list, + variable_value=value + ) diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 88d226d3033ab1..fdfb401bd4d334 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -2,6 +2,7 @@ from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle +from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle from .generate_conversation_name_when_first_message_created import handle diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py new file mode 100644 index 00000000000000..25fba591d02af8 --- /dev/null +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -0,0 +1,20 @@ +from events.app_event import app_was_created +from extensions.ext_database import db +from models.model import Site + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create site record when an app is created.""" + app = sender + account = kwargs.get('account') + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py 
b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8c335f201fb590..53cbb2ecdc96ce 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -1,4 +1,4 @@ -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db @@ -8,9 +8,12 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') - model_config = application_generate_entity.app_orchestration_config_entity.model_config + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return + + model_config = application_generate_entity.model_config provider_model_bundle = model_config.provider_model_bundle provider_configuration = provider_model_bundle.configuration @@ -43,7 +46,7 @@ def handle(sender, **kwargs): if used_quota is not None: db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, + Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 74dc8d5112fa5c..31535bf4ef68fb 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,6 +1,7 @@ -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db +from models.model import AppMode @message_was_created.connect @@ -15,7 +16,7 @@ def handle(sender, **kwargs): auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) if auto_generate_conversation_name and is_first_message: - if conversation.mode == 'chat': + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: return diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 69b3a90e441659..ae983cc5d1a537 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,10 +9,13 @@ @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: 
ApplicationGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, - Provider.provider_name == application_generate_entity.app_orchestration_config_entity.model_config.provider + Provider.tenant_id == application_generate_entity.app_config.tenant_id, + Provider.provider_name == application_generate_entity.model_config.provider ).update({'last_used': datetime.utcnow()}) db.session.commit() diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 5974de34de201b..c77808447519fc 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,20 +2,13 @@ from libs.helper import TimestampField -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - - annotation_fields = { "id": fields.String, "question": fields.String, "answer": fields.Raw(attribute='content'), "hit_count": fields.Integer, "created_at": TimestampField, - # 'account': fields.Nested(account_fields, allow_null=True) + # 'account': fields.Nested(simple_account_fields, allow_null=True) } annotation_list_fields = { diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index e6c1272086581d..ccb95ad5731147 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -5,6 +5,7 @@ app_detail_kernel_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -41,16 +42,13 @@ app_detail_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -66,14 +64,11 @@ app_partial_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -117,16 +112,13 @@ app_detail_fields_with_site = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_agent': fields.Boolean, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, 'created_at': 
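The dictionaries in these fields modules are flask_restful marshaling maps. The allow_null=True added to the Nested model_config entries is what lets an app without an app_model_config serialize as null instead of emitting a nested object full of empty values. A runnable sketch with made-up field maps:

from types import SimpleNamespace

from flask_restful import fields, marshal

simple_account_fields = {'id': fields.String, 'name': fields.String, 'email': fields.String}

app_fields = {
    'name': fields.String,
    # allow_null=True: a missing relation marshals as None rather than
    # as a nested object whose fields are all None
    'creator': fields.Nested(simple_account_fields, allow_null=True),
}

app = SimpleNamespace(name='demo app', creator=None)
print(marshal(app, app_fields))  # creator comes out as None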
TimestampField, diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1adc836aa2479f..747b0b86abf3ef 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,5 +1,6 @@ from flask_restful import fields +from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -8,31 +9,25 @@ def format(self, value): return value[0]['text'] if value else '' -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - feedback_fields = { 'rating': fields.String, 'content': fields.String, 'from_source': fields.String, 'from_end_user_id': fields.String, - 'from_account': fields.Nested(account_fields, allow_null=True), + 'from_account': fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { 'id': fields.String, 'question': fields.String, 'content': fields.String, - 'account': fields.Nested(account_fields, allow_null=True), + 'account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } annotation_hit_history_fields = { 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(account_fields, allow_null=True), + 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } @@ -71,6 +66,7 @@ def format(self, value): 'from_end_user_id': fields.String, 'from_account_id': fields.String, 'feedbacks': fields.List(fields.Nested(feedback_fields)), + 'workflow_run_id': fields.String, 'annotation': fields.Nested(annotation_fields, allow_null=True), 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), 'created_at': TimestampField, diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py new file mode 100644 index 00000000000000..ee630c12c2e9aa --- /dev/null +++ b/api/fields/end_user_fields.py @@ -0,0 +1,8 @@ +from flask_restful import fields + +simple_end_user_fields = { + 'id': fields.String, + 'type': fields.String, + 'is_anonymous': fields.Boolean, + 'session_id': fields.String, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 821d3c0adef3ab..35cc5a64755eca 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -17,8 +17,7 @@ 'is_pinned': fields.Boolean, 'last_used_at': TimestampField, 'editable': fields.Boolean, - 'uninstallable': fields.Boolean, - 'is_agent': fields.Boolean, + 'uninstallable': fields.Boolean } installed_app_list_fields = { diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py new file mode 100644 index 00000000000000..79164b3848536d --- /dev/null +++ b/api/fields/member_fields.py @@ -0,0 +1,38 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +simple_account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'is_password_set': fields.Boolean, + 'interface_language': fields.String, + 'interface_theme': fields.String, + 'timezone': fields.String, + 'last_login_at': TimestampField, + 'last_login_ip': fields.String, + 'created_at': TimestampField +} + +account_with_role_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'last_login_at': TimestampField, + 'created_at': TimestampField, + 'role': fields.String, + 
'status': fields.String, +} + +account_with_role_list_fields = { + 'accounts': fields.List(fields.Nested(account_with_role_fields)) +} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py new file mode 100644 index 00000000000000..e230c159fba59a --- /dev/null +++ b/api/fields/workflow_app_log_fields.py @@ -0,0 +1,24 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from fields.workflow_run_fields import workflow_run_for_log_fields +from libs.helper import TimestampField + +workflow_app_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), + "created_from": fields.String, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": TimestampField +} + +workflow_app_log_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py new file mode 100644 index 00000000000000..9919a440e8d91e --- /dev/null +++ b/api/fields/workflow_fields.py @@ -0,0 +1,14 @@ +from flask_restful import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +workflow_fields = { + 'id': fields.String, + 'graph': fields.Raw(attribute='graph_dict'), + 'features': fields.Raw(attribute='features_dict'), + 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), + 'created_at': TimestampField, + 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), + 'updated_at': TimestampField +} diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py new file mode 100644 index 00000000000000..72510cd27ac621 --- /dev/null +++ b/api/fields/workflow_run_fields.py @@ -0,0 +1,80 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +workflow_run_for_log_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_for_list_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "status": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_pagination_fields = { + 'limit': fields.Integer(attribute='limit'), + 'has_more': fields.Boolean(attribute='has_more'), + 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') +} + +workflow_run_detail_fields = { + "id": fields.String, + 
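workflow_app_log_pagination_fields above renames Flask-SQLAlchemy pagination attributes (per_page, has_next, items) to the API's limit/has_more/data names via the attribute= argument. A sketch with a stand-in pagination object:

from types import SimpleNamespace

from flask_restful import fields, marshal

log_fields = {'id': fields.String}

pagination_fields = {
    'page': fields.Integer,
    'limit': fields.Integer(attribute='per_page'),    # API name <- source attribute
    'total': fields.Integer,
    'has_more': fields.Boolean(attribute='has_next'),
    'data': fields.List(fields.Nested(log_fields), attribute='items'),
}

# mimics the attributes of a Flask-SQLAlchemy Pagination object
page = SimpleNamespace(page=1, per_page=20, total=1, has_next=False,
                       items=[SimpleNamespace(id='log-1')])
print(marshal(page, pagination_fields))  # limit=20, has_more=False, data=[{'id': 'log-1'}]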
"sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.Raw(attribute='graph_dict'), + "inputs": fields.Raw(attribute='inputs_dict'), + "status": fields.String, + "outputs": fields.Raw(attribute='outputs_dict'), + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_steps": fields.Integer, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_node_execution_fields = { + "id": fields.String, + "index": fields.Integer, + "predecessor_node_id": fields.String, + "node_id": fields.String, + "node_type": fields.String, + "title": fields.String, + "inputs": fields.Raw(attribute='inputs_dict'), + "process_data": fields.Raw(attribute='process_data_dict'), + "outputs": fields.Raw(attribute='outputs_dict'), + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), + "created_at": TimestampField, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "finished_at": TimestampField +} + +workflow_run_node_execution_list_fields = { + 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), +} diff --git a/api/libs/helper.py b/api/libs/helper.py index a35f4ad4710868..3eb14c50f049e3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -15,7 +15,7 @@ def run(script): class TimestampField(fields.Raw): - def format(self, value): + def format(self, value) -> int: return int(value.timestamp()) diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py new file mode 100644 index 00000000000000..f388b99b9068a0 --- /dev/null +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -0,0 +1,48 @@ +"""conversation columns set nullable + +Revision ID: 42e85ed5564d +Revises: f9107f83abab +Create Date: 2024-03-07 08:30:29.133614 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '42e85ed5564d' +down_revision = 'f9107f83abab' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py new file mode 100644 index 00000000000000..299f442de989be --- /dev/null +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -0,0 +1,35 @@ +"""enable tool file without conversation id + +Revision ID: 563cf8bf777b +Revises: b5429b71023c +Create Date: 2024-03-14 04:54:56.679506 + +""" +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '563cf8bf777b' +down_revision = 'b5429b71023c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py new file mode 100644 index 00000000000000..8fadf2dc6c98c7 --- /dev/null +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -0,0 +1,142 @@ +"""add workflow + +Revision ID: b289e2408ee2 +Revises: 16830a790f0f +Create Date: 2024-02-19 12:47:24.646954 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b289e2408ee2' +down_revision = '16830a790f0f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow_app_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_app_log_pkey') + ) + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.create_index('workflow_app_log_app_idx', ['tenant_id', 'app_id'], unique=False) + + op.create_table('workflow_node_executions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=True), + sa.Column('index', sa.Integer(), nullable=False), + sa.Column('predecessor_node_id', sa.String(length=255), nullable=True), + sa.Column('node_id', sa.String(length=255), nullable=False), + sa.Column('node_type', sa.String(length=255), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('process_data', sa.Text(), nullable=True), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('execution_metadata', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') + ) + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.create_index('workflow_node_execution_node_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'], unique=False) + batch_op.create_index('workflow_node_execution_workflow_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'workflow_run_id'], unique=False) + + op.create_table('workflow_runs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('sequence_number', sa.Integer(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('inputs', sa.Text(), nullable=True), + 
sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') + ) + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'triggered_from'], unique=False) + + op.create_table('workflows', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('features', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_pkey') + ) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_run_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('workflow_run_id') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('workflow_id') + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_index('workflow_version_idx') + + op.drop_table('workflows') + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_triggerd_from_idx') + + op.drop_table('workflow_runs') + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_workflow_run_idx') + batch_op.drop_index('workflow_node_execution_node_run_idx') + + op.drop_table('workflow_node_executions') + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_app_log_app_idx') + + op.drop_table('workflow_app_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py new file mode 100644 index 00000000000000..ee81fdab2872a2 --- /dev/null +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -0,0 +1,41 @@ +"""messages columns set nullable + +Revision ID: b5429b71023c +Revises: 42e85ed5564d +Create Date: 2024-03-07 09:52:00.846136 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'b5429b71023c' +down_revision = '42e85ed5564d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py new file mode 100644 index 00000000000000..aefbe43f148f26 --- /dev/null +++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py @@ -0,0 +1,70 @@ +"""set model config column nullable + +Revision ID: cc04d0998d4d +Revises: b289e2408ee2 +Create Date: 2024-02-27 03:47:47.376325 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'cc04d0998d4d' +down_revision = 'b289e2408ee2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py new file mode 100644 index 00000000000000..3e5ae0d67d7e58 --- /dev/null +++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py @@ -0,0 +1,31 @@ +"""add desc for apps + +Revision ID: f9107f83abab +Revises: cc04d0998d4d +Create Date: 2024-02-28 08:16:14.090481 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'f9107f83abab' +down_revision = 'cc04d0998d4d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('description') + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 44d37d3052e8cd..47eec535428105 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1 +1,44 @@ -# -*- coding:utf-8 -*- \ No newline at end of file +from enum import Enum + + +class CreatedByRole(Enum): + """ + Enum class for createdByRole + """ + ACCOUNT = "account" + END_USER = "end_user" + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for role in cls: + if role.value == value: + return role + raise ValueError(f'invalid createdByRole value {value}') + + +class CreatedFrom(Enum): + """ + Enum class for createdFrom + """ + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + + @classmethod + def value_of(cls, value: str) -> 'CreatedFrom': + """ + Get value of given mode. 
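Six new Alembic revisions land in this patch. Read together, their down_revision pointers form a single linear chain, so they apply in this order:

# Apply order of the new revisions, reconstructed from the down_revision
# pointers in the migration files above (oldest first):
#
#   16830a790f0f  (prior revision referenced by b289e2408ee2)
#    -> b289e2408ee2  add workflow tables, apps.workflow_id, messages.workflow_run_id
#    -> cc04d0998d4d  set model config columns nullable, default api_rpm/api_rph
#    -> f9107f83abab  add apps.description
#    -> 42e85ed5564d  conversation columns set nullable
#    -> b5429b71023c  messages columns set nullable
#    -> 563cf8bf777b  enable tool file without conversation id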
+ + :param value: mode value + :return: mode + """ + for role in cls: + if role.value == value: + return role + raise ValueError(f'invalid createdFrom value {value}') diff --git a/api/models/model.py b/api/models/model.py index 8776f896730a07..5a7311a0c72ecc 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,5 +1,7 @@ import json import uuid +from enum import Enum +from typing import Optional from flask import current_app, request from flask_login import UserMixin @@ -24,6 +26,28 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + +class AppMode(Enum): + COMPLETION = 'completion' + WORKFLOW = 'workflow' + CHAT = 'chat' + ADVANCED_CHAT = 'advanced-chat' + AGENT_CHAT = 'agent-chat' + CHANNEL = 'channel' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -34,15 +58,17 @@ class App(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(UUID, nullable=False) name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(UUID, nullable=True) + workflow_id = db.Column(UUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False) - api_rph = db.Column(db.Integer, nullable=False) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -55,10 +81,19 @@ def site(self): return site @property - def app_model_config(self): - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() - return app_model_config + def app_model_config(self) -> Optional['AppModelConfig']: + if self.app_model_config_id: + return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + + return None + + @property + def workflow(self): + if self.workflow_id: + from models.workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None @property def api_base_url(self): @@ -69,7 +104,7 @@ def api_base_url(self): def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - + @property def is_agent(self) -> bool: app_model_config = self.app_model_config @@ -78,10 +113,10 @@ def is_agent(self) -> bool: if not app_model_config.agent_mode: return False if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + and
self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: return True return False - + @property def deleted_tools(self) -> list: # get agent mode tools @@ -129,6 +164,7 @@ def deleted_tools(self) -> list: return deleted_tools + class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( @@ -138,9 +174,9 @@ class AppModelConfig(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) - configs = db.Column(db.JSON, nullable=False) + provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) + configs = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) opening_statement = db.Column(db.Text) @@ -156,7 +192,7 @@ class AppModelConfig(db.Model): agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, default='simple') + prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) @@ -263,9 +299,6 @@ def file_upload_dict(self) -> dict: def to_dict(self) -> dict: return { - "provider": "", - "model_id": "", - "configs": {}, "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, @@ -289,26 +322,29 @@ def to_dict(self) -> dict: } def from_model_config_dict(self, model_config: dict): - self.provider = "" - self.model_id = "" - self.configs = {} - self.opening_statement = model_config['opening_statement'] - self.suggested_questions = json.dumps(model_config['suggested_questions']) - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) + self.opening_statement = model_config.get('opening_statement') + self.suggested_questions = json.dumps(model_config['suggested_questions']) \ + if model_config.get('suggested_questions') else None + self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ + if model_config.get('suggested_questions_after_answer') else None self.speech_to_text = json.dumps(model_config['speech_to_text']) \ if model_config.get('speech_to_text') else None self.text_to_speech = json.dumps(model_config['text_to_speech']) \ if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) + self.more_like_this = json.dumps(model_config['more_like_this']) \ + if model_config.get('more_like_this') else None self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ if model_config.get('sensitive_word_avoidance') else None self.external_data_tools = json.dumps(model_config['external_data_tools']) \ if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) - self.user_input_form = json.dumps(model_config['user_input_form']) + self.model = json.dumps(model_config['model']) \ + if model_config.get('model') else None + 
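from_model_config_dict now serializes each optional section only when it is present and truthy, storing NULL otherwise, which matches the newly nullable columns. The recurring pattern, as a standalone sketch:

import json

def dump_or_none(model_config: dict, key: str) -> str | None:
    # serialize the section only when present and truthy; otherwise store NULL
    return json.dumps(model_config[key]) if model_config.get(key) else None

print(dump_or_none({'agent_mode': {'enabled': True}}, 'agent_mode'))  # '{"enabled": true}'
print(dump_or_none({}, 'agent_mode'))                                 # None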
self.user_input_form = json.dumps(model_config['user_input_form']) \ + if model_config.get('user_input_form') else None self.dataset_query_variable = model_config.get('dataset_query_variable') self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) + self.agent_mode = json.dumps(model_config['agent_mode']) \ + if model_config.get('agent_mode') else None self.retriever_resource = json.dumps(model_config['retriever_resource']) \ if model_config.get('retriever_resource') else None self.prompt_type = model_config.get('prompt_type', 'simple') @@ -326,9 +362,6 @@ def copy(self): new_app_model_config = AppModelConfig( id=self.id, app_id=self.app_id, - provider="", - model_id="", - configs={}, opening_statement=self.opening_statement, suggested_questions=self.suggested_questions, suggested_questions_after_answer=self.suggested_questions_after_answer, @@ -408,12 +441,6 @@ def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - @property - def is_agent(self) -> bool: - app = self.app - if not app: - return False - return app.is_agent class Conversation(db.Model): __tablename__ = 'conversations' @@ -424,10 +451,10 @@ class Conversation(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - app_model_config_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) + app_model_config_id = db.Column(UUID, nullable=True) + model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - model_id = db.Column(db.String(255), nullable=False) + model_id = db.Column(db.String(255), nullable=True) mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) @@ -558,8 +585,8 @@ class Message(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) + model_provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) inputs = db.Column(db.JSON) @@ -581,6 +608,7 @@ class Message(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + workflow_run_id = db.Column(UUID) @property def user_feedback(self): @@ -679,6 +707,14 @@ def files(self): return files + @property + def workflow_run(self): + if self.workflow_run_id: + from models.workflow import WorkflowRun + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + return None + class MessageFeedback(db.Model): __tablename__ = 'message_feedbacks' diff --git a/api/models/tools.py b/api/models/tools.py index bceef7a8290151..4bdf2503ce0619 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -218,7 +218,7 @@ class ToolFile(db.Model): # tenant id tenant_id = db.Column(UUID, nullable=False) # conversation id - conversation_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, nullable=True) # file key
file_key = db.Column(db.String(255), nullable=False) # mime type diff --git a/api/models/workflow.py b/api/models/workflow.py new file mode 100644 index 00000000000000..9c5b2a0b8f3135 --- /dev/null +++ b/api/models/workflow.py @@ -0,0 +1,519 @@ +import json +from enum import Enum +from typing import Union + +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db +from models.account import Account + + +class CreatedByRole(Enum): + """ + Created By Role Enum + """ + ACCOUNT = 'account' + END_USER = 'end_user' + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid created by role value {value}') + + +class WorkflowType(Enum): + """ + Workflow Type Enum + """ + WORKFLOW = 'workflow' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow type value {value}') + + @classmethod + def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': + """ + Get workflow type from app mode. + + :param app_mode: app mode + :return: workflow type + """ + from models.model import AppMode + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) + return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT + + +class Workflow(db.Model): + """ + Workflow, for `Workflow App` and `Chat App workflow mode`. + + Attributes: + + - id (uuid) Workflow ID, pk + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - type (string) Workflow type + + `workflow` for `Workflow App` + + `chat` for `Chat App workflow mode` + + - version (string) Version + + `draft` for draft version (only one for each app), other for version number (redundant) + + - graph (text) Workflow canvas configuration (JSON) + + The entire canvas configuration JSON, including Node, Edge, and other configurations + + - nodes (array[object]) Node list, see Node Schema + + - edges (array[object]) Edge list, see Edge Schema + + - created_by (uuid) Creator ID + - created_at (timestamp) Creation time + - updated_by (uuid) `optional` Last updater ID + - updated_at (timestamp) `optional` Last update time + """ + + __tablename__ = 'workflows' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_pkey'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + features = db.Column(db.Text) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(UUID) + updated_at = db.Column(db.DateTime) + + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + + @property + def graph_dict(self): + return json.loads(self.graph) if self.graph else None + + @property + def features_dict(self): + return json.loads(self.features) if self.features else 
None + + def user_input_form(self) -> list: + # get start node from graph + if not self.graph: + return [] + + graph_dict = self.graph_dict + if 'nodes' not in graph_dict: + return [] + + start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) + if not start_node: + return [] + + # get user_input_form from start node + return start_node.get('data', {}).get('variables', []) + + +class WorkflowRunTriggeredFrom(Enum): + """ + Workflow Run Triggered From Enum + """ + DEBUGGING = 'debugging' + APP_RUN = 'app-run' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run triggered from value {value}') + + +class WorkflowRunStatus(Enum): + """ + Workflow Run Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + STOPPED = 'stopped' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunStatus': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run status value {value}') + + +class WorkflowRun(db.Model): + """ + Workflow Run + + Attributes: + + - id (uuid) Run ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID + - type (string) Workflow type + - triggered_from (string) Trigger source + + `debugging` for canvas debugging + + `app-run` for (published) app execution + + - version (string) Version + - graph (text) Workflow canvas configuration (JSON) + - inputs (text) Input parameters + - status (string) Execution status, `running` / `succeeded` / `failed` / `stopped` + - outputs (text) `optional` Output content + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - total_tokens (int) `optional` Total tokens used + - total_steps (int) Total steps (redundant), default 0 + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + + - created_by (uuid) Runner ID + - created_at (timestamp) Run time + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_runs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), + db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + sequence_number = db.Column(db.Integer, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + inputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + outputs = db.Column(db.Text) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + total_steps = db.Column(db.Integer, server_default=db.text('0')) + created_by_role = db.Column(db.String(255), nullable=False) + created_by = db.Column(UUID, 
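Workflow.user_input_form above resolves the form by finding the start node in the canvas JSON and returning its variables list. A runnable sketch against a made-up minimal graph:

import json

# made-up minimal canvas: one start node carrying the form variables
graph = json.dumps({
    'nodes': [
        {'data': {'type': 'start',
                  'variables': [{'variable': 'query', 'label': 'Query'}]}},
        {'data': {'type': 'llm'}},
    ],
    'edges': [],
})

graph_dict = json.loads(graph)
start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None)
print(start_node.get('data', {}).get('variables', []) if start_node else [])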
nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + finished_at = db.Column(db.DateTime) + + @property + def created_by_account(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) + return EndUser.query.get(self.created_by) \ + if created_by_role == CreatedByRole.END_USER else None + + @property + def graph_dict(self): + return json.loads(self.graph) if self.graph else None + + @property + def inputs_dict(self): + return json.loads(self.inputs) if self.inputs else None + + @property + def outputs_dict(self): + return json.loads(self.outputs) if self.outputs else None + + +class WorkflowNodeExecutionTriggeredFrom(Enum): + """ + Workflow Node Execution Triggered From Enum + """ + SINGLE_STEP = 'single-step' + WORKFLOW_RUN = 'workflow-run' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow node execution triggered from value {value}') + + +class WorkflowNodeExecutionStatus(Enum): + """ + Workflow Node Execution Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow node execution status value {value}') + + +class WorkflowNodeExecution(db.Model): + """ + Workflow Node Execution + + - id (uuid) Execution ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - workflow_id (uuid) Workflow ID + - triggered_from (string) Trigger source + + `single-step` for single-step debugging + + `workflow-run` for workflow execution (debugging / user execution) + + - workflow_run_id (uuid) `optional` Workflow run ID + + Null for single-step debugging. 
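WorkflowRun resolves its creator polymorphically: created_by holds either an Account id or an EndUser id, and created_by_role decides which table to consult; WorkflowNodeExecution and WorkflowAppLog below repeat the pattern. A sketch with dict lookups standing in for the ORM queries:

from enum import Enum

class CreatedByRole(Enum):
    ACCOUNT = 'account'
    END_USER = 'end_user'

accounts = {'a-1': 'console account'}  # stands in for Account.query
end_users = {'u-1': 'end user'}        # stands in for EndUser.query

def created_by_account(role: str, created_by: str):
    return accounts.get(created_by) if CreatedByRole(role) == CreatedByRole.ACCOUNT else None

def created_by_end_user(role: str, created_by: str):
    return end_users.get(created_by) if CreatedByRole(role) == CreatedByRole.END_USER else None

print(created_by_account('account', 'a-1'))   # 'console account'
print(created_by_end_user('account', 'a-1'))  # None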
+ + - index (int) Execution sequence number, used for displaying Tracing Node order + - predecessor_node_id (string) `optional` Predecessor node ID, used for displaying execution path + - node_id (string) Node ID + - node_type (string) Node type, such as `start` + - title (string) Node title + - inputs (json) All predecessor node variable content used in the node + - process_data (json) Node process data + - outputs (json) `optional` Node output variables + - status (string) Execution status, `running` / `succeeded` / `failed` + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - execution_metadata (text) Metadata + + - total_tokens (int) `optional` Total tokens used + + - total_price (decimal) `optional` Total cost + + - currency (string) `optional` Currency, such as USD / RMB + + - created_at (timestamp) Run time + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + + - created_by (uuid) Runner ID + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_node_executions' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), + db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'workflow_run_id'), + db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + workflow_run_id = db.Column(UUID) + index = db.Column(db.Integer, nullable=False) + predecessor_node_id = db.Column(db.String(255)) + node_id = db.Column(db.String(255), nullable=False) + node_type = db.Column(db.String(255), nullable=False) + title = db.Column(db.String(255), nullable=False) + inputs = db.Column(db.Text) + process_data = db.Column(db.Text) + outputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + execution_metadata = db.Column(db.Text) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by_role = db.Column(db.String(255), nullable=False) + created_by = db.Column(UUID, nullable=False) + finished_at = db.Column(db.DateTime) + + @property + def created_by_account(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) + return EndUser.query.get(self.created_by) \ + if created_by_role == CreatedByRole.END_USER else None + + @property + def inputs_dict(self): + return json.loads(self.inputs) if self.inputs else None + + @property + def outputs_dict(self): + return json.loads(self.outputs) if self.outputs else None + + @property + def process_data_dict(self): + return json.loads(self.process_data) if self.process_data else None + + @property + def execution_metadata_dict(self): + return json.loads(self.execution_metadata) if self.execution_metadata else None + + +class WorkflowAppLogCreatedFrom(Enum): + """ + Workflow App Log Created 
From Enum + """ + SERVICE_API = 'service-api' + WEB_APP = 'web-app' + INSTALLED_APP = 'installed-app' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow app log created from value {value}') + + +class WorkflowAppLog(db.Model): + """ + Workflow App execution log, excluding workflow debugging records. + + Attributes: + + - id (uuid) run ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - workflow_id (uuid) Associated Workflow ID + - workflow_run_id (uuid) Associated Workflow Run ID + - created_from (string) Creation source + + `service-api` App Execution OpenAPI + + `web-app` WebApp + + `installed-app` Installed App + + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + + - created_by (uuid) Creator ID, depends on the user table according to created_by_role + - created_at (timestamp) Creation time + """ + + __tablename__ = 'workflow_app_logs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), + db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + workflow_run_id = db.Column(UUID, nullable=False) + created_from = db.Column(db.String(255), nullable=False) + created_by_role = db.Column(db.String(255), nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def workflow_run(self): + return WorkflowRun.query.get(self.workflow_run_id) + + @property + def created_by_account(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + created_by_role = CreatedByRole.value_of(self.created_by_role) + return EndUser.query.get(self.created_by) \ + if created_by_role == CreatedByRole.END_USER else None diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d52f6e20c219a8..213df262223d8a 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,7 @@ import copy -from core.prompt.advanced_prompt_templates import ( +from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, @@ -13,7 +13,7 @@ COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) -from core.prompt.prompt_transform import AppMode +from models.model import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2e21e562665938..c84f6fbf454daf 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,527 +1,18 @@ -import re -import uuid - -from core.entities.agent_entities import PlanningStrategy -from core.external_data_tool.factory import ExternalDataToolFactory -from 
core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers import model_provider_factory -from core.moderation.factory import ModerationFactory -from core.prompt.prompt_transform import AppMode -from core.provider_manager import ProviderManager -from models.account import Account -from services.dataset_service import DatasetService - -SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from models.model import AppMode class AppModelConfigService: - @classmethod - def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool: - # verify if the dataset ID exists - dataset = DatasetService.get_dataset(dataset_id) - - if not dataset: - return False - - if dataset.tenant_id != account.current_tenant_id: - return False - - return True @classmethod - def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: - # 6. model.completion_params - if not isinstance(cp, dict): - raise ValueError("model.completion_params must be of object type") - - # stop - if 'stop' not in cp: - cp["stop"] = [] - elif not isinstance(cp["stop"], list): - raise ValueError("stop in model.completion_params must be of list type") - - if len(cp["stop"]) > 4: - raise ValueError("stop sequences must be less than 4") - - return cp - - @classmethod - def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict: - # opening_statement - if 'opening_statement' not in config or not config["opening_statement"]: - config["opening_statement"] = "" - - if not isinstance(config["opening_statement"], str): - raise ValueError("opening_statement must be of string type") - - # suggested_questions - if 'suggested_questions' not in config or not config["suggested_questions"]: - config["suggested_questions"] = [] - - if not isinstance(config["suggested_questions"], list): - raise ValueError("suggested_questions must be of list type") - - for question in config["suggested_questions"]: - if not isinstance(question, str): - raise ValueError("Elements in suggested_questions list must be of string type") - - # suggested_questions_after_answer - if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]: - config["suggested_questions_after_answer"] = { - "enabled": False - } - - if not isinstance(config["suggested_questions_after_answer"], dict): - raise ValueError("suggested_questions_after_answer must be of dict type") - - if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: - config["suggested_questions_after_answer"]["enabled"] = False - - if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): - raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") - - # speech_to_text - if 'speech_to_text' not in config or not config["speech_to_text"]: - config["speech_to_text"] = { - "enabled": False - } - - if not isinstance(config["speech_to_text"], dict): - raise ValueError("speech_to_text must be of dict type") - - if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: - config["speech_to_text"]["enabled"] = False - - if not 
isinstance(config["speech_to_text"]["enabled"], bool): - raise ValueError("enabled in speech_to_text must be of boolean type") - - # text_to_speech - if 'text_to_speech' not in config or not config["text_to_speech"]: - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } - - if not isinstance(config["text_to_speech"], dict): - raise ValueError("text_to_speech must be of dict type") - - if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]: - config["text_to_speech"]["enabled"] = False - config["text_to_speech"]["voice"] = "" - config["text_to_speech"]["language"] = "" - - if not isinstance(config["text_to_speech"]["enabled"], bool): - raise ValueError("enabled in text_to_speech must be of boolean type") - - # return retriever resource - if 'retriever_resource' not in config or not config["retriever_resource"]: - config["retriever_resource"] = { - "enabled": False - } - - if not isinstance(config["retriever_resource"], dict): - raise ValueError("retriever_resource must be of dict type") - - if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: - config["retriever_resource"]["enabled"] = False - - if not isinstance(config["retriever_resource"]["enabled"], bool): - raise ValueError("enabled in retriever_resource must be of boolean type") - - # more_like_this - if 'more_like_this' not in config or not config["more_like_this"]: - config["more_like_this"] = { - "enabled": False - } - - if not isinstance(config["more_like_this"], dict): - raise ValueError("more_like_this must be of dict type") - - if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: - config["more_like_this"]["enabled"] = False - - if not isinstance(config["more_like_this"]["enabled"], bool): - raise ValueError("enabled in more_like_this must be of boolean type") - - # model - if 'model' not in config: - raise ValueError("model is required") - - if not isinstance(config["model"], dict): - raise ValueError("model must be of object type") - - # model.provider - provider_entities = model_provider_factory.get_providers() - model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: - raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") - - # model.name - if 'name' not in config["model"]: - raise ValueError("model.name is required") - - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM - ) - if not models: - raise ValueError("model.name must be in the specified model list") - - model_ids = [m.model for m in models] - if config["model"]["name"] not in model_ids: - raise ValueError("model.name must be in the specified model list") - - model_mode = None - for model in models: - if model.model == config["model"]["name"]: - model_mode = model.model_properties.get(ModelPropertyKey.MODE) - break - - # model.mode - if model_mode: - config['model']["mode"] = model_mode + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + if app_mode == AppMode.CHAT: + return ChatAppConfigManager.config_validate(tenant_id, config) + elif app_mode == AppMode.AGENT_CHAT: + return AgentChatAppConfigManager.config_validate(tenant_id, config) + elif app_mode == AppMode.COMPLETION: + return 
CompletionAppConfigManager.config_validate(tenant_id, config) else: - config['model']["mode"] = "completion" - - # model.completion_params - if 'completion_params' not in config["model"]: - raise ValueError("model.completion_params is required") - - config["model"]["completion_params"] = cls.validate_model_completion_params( - config["model"]["completion_params"], - config["model"]["name"] - ) - - # user_input_form - if "user_input_form" not in config or not config["user_input_form"]: - config["user_input_form"] = [] - - if not isinstance(config["user_input_form"], list): - raise ValueError("user_input_form must be a list of objects") - - variables = [] - for item in config["user_input_form"]: - key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "external_data_tool"]: - raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") - - form_item = item[key] - if 'label' not in form_item: - raise ValueError("label is required in user_input_form") - - if not isinstance(form_item["label"], str): - raise ValueError("label in user_input_form must be of string type") - - if 'variable' not in form_item: - raise ValueError("variable is required in user_input_form") - - if not isinstance(form_item["variable"], str): - raise ValueError("variable in user_input_form must be of string type") - - pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") - if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") - - variables.append(form_item["variable"]) - - if 'required' not in form_item or not form_item["required"]: - form_item["required"] = False - - if not isinstance(form_item["required"], bool): - raise ValueError("required in user_input_form must be of boolean type") - - if key == "select": - if 'options' not in form_item or not form_item["options"]: - form_item["options"] = [] - - if not isinstance(form_item["options"], list): - raise ValueError("options in user_input_form must be a list of strings") - - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: - raise ValueError("default value in user_input_form must be in the options list") - - # pre_prompt - if "pre_prompt" not in config or not config["pre_prompt"]: - config["pre_prompt"] = "" - - if not isinstance(config["pre_prompt"], str): - raise ValueError("pre_prompt must be of string type") - - # agent_mode - if "agent_mode" not in config or not config["agent_mode"]: - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]: - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]: - config["agent_mode"]["tools"] = [] - - if not 
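For reference, the new per-mode dispatch above replaces the monolithic validator being deleted here; a hedged sketch of a call, with an abbreviated config (the real per-mode validators may require additional keys) and a placeholder tenant id.

```python
from models.model import AppMode
from services.app_model_config_service import AppModelConfigService

config = {
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo",
        "completion_params": {}
    }
}

# Routed to ChatAppConfigManager.config_validate under the hood.
validated = AppModelConfigService.validate_configuration(
    tenant_id="tenant-id",  # placeholder
    config=config,
    app_mode=AppMode.CHAT
)
```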
isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in SUPPORT_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not cls.is_dataset_exists(account, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - - # advanced prompt validation - cls.is_advanced_prompt_valid(config, app_mode) - - # external data tools validation - cls.is_external_data_tools_valid(tenant_id, config) - - # moderation validation - cls.is_moderation_valid(tenant_id, config) - - # file upload validation - cls.is_file_upload_valid(config) - - # Filter out extra parameters - filtered_config = { - "opening_statement": config["opening_statement"], - "suggested_questions": config["suggested_questions"], - "suggested_questions_after_answer": config["suggested_questions_after_answer"], - "speech_to_text": config["speech_to_text"], - "text_to_speech": config["text_to_speech"], - "retriever_resource": config["retriever_resource"], - "more_like_this": config["more_like_this"], - "sensitive_word_avoidance": config["sensitive_word_avoidance"], - "external_data_tools": config["external_data_tools"], - "model": { - "provider": config["model"]["provider"], - "name": config["model"]["name"], - "mode": config['model']["mode"], - "completion_params": config["model"]["completion_params"] - }, - "user_input_form": config["user_input_form"], - "dataset_query_variable": config.get('dataset_query_variable'), - "pre_prompt": config["pre_prompt"], - "agent_mode": config["agent_mode"], - "prompt_type": config["prompt_type"], - "chat_prompt_config": config["chat_prompt_config"], - "completion_prompt_config": config["completion_prompt_config"], - "dataset_configs": config["dataset_configs"], - "file_upload": config["file_upload"] - } - - return filtered_config - - @classmethod - def is_moderation_valid(cls, tenant_id: str, config: dict): - if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: - config["sensitive_word_avoidance"] = { - "enabled": False - } - - if not isinstance(config["sensitive_word_avoidance"], dict): - raise ValueError("sensitive_word_avoidance must be of dict type") - - if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: - config["sensitive_word_avoidance"]["enabled"] = False - - if not 
config["sensitive_word_avoidance"]["enabled"]: - return - - if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: - raise ValueError("sensitive_word_avoidance.type is required") - - type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] - - ModerationFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_file_upload_valid(cls, config: dict): - if 'file_upload' not in config or not config["file_upload"]: - config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be of dict type") - - # check image config - if 'image' not in config["file_upload"] or not config["file_upload"]["image"]: - config["file_upload"]["image"] = {"enabled": False} - - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: - raise ValueError("detail must be in ['high', 'low']") - - transfer_methods = config['file_upload']['image']['transfer_methods'] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in ['remote_url', 'local_file']: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") - - @classmethod - def is_external_data_tools_valid(cls, tenant_id: str, config: dict): - if 'external_data_tools' not in config or not config["external_data_tools"]: - config["external_data_tools"] = [] - - if not isinstance(config["external_data_tools"], list): - raise ValueError("external_data_tools must be of list type") - - for tool in config["external_data_tools"]: - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - - if not tool["enabled"]: - continue - - if "type" not in tool or not tool["type"]: - raise ValueError("external_data_tools[].type is required") - - type = tool["type"] - config = tool["config"] - - ExternalDataToolFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") - - @classmethod - def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: - # prompt_type - if 'prompt_type' not in config or not config["prompt_type"]: - config["prompt_type"] = "simple" - - if config['prompt_type'] not in ['simple', 'advanced']: - raise ValueError("prompt_type must be in ['simple', 'advanced']") - - # chat_prompt_config - if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: - config["chat_prompt_config"] = {} - - if not isinstance(config["chat_prompt_config"], dict): - raise ValueError("chat_prompt_config must be of object type") - - # completion_prompt_config - if 'completion_prompt_config' not in config or not config["completion_prompt_config"]: - config["completion_prompt_config"] = 
{} - - if not isinstance(config["completion_prompt_config"], dict): - raise ValueError("completion_prompt_config must be of object type") - - # dataset_configs - if 'dataset_configs' not in config or not config["dataset_configs"]: - config["dataset_configs"] = {'retrieval_model': 'single'} - - if 'datasets' not in config["dataset_configs"] or not config["dataset_configs"]["datasets"]: - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config["dataset_configs"]['retrieval_model'] == 'multiple': - if not config["dataset_configs"]['reranking_model']: - raise ValueError("reranking_model has not been set") - if not isinstance(config["dataset_configs"]['reranking_model'], dict): - raise ValueError("reranking_model must be of object type") - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config['prompt_type'] == 'advanced': - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") - - if config['model']["mode"] not in ['chat', 'completion']: - raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - - if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion": - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] - - if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' - - if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' - - if config['model']["mode"] == "chat": - prompt_list = config['chat_prompt_config']['prompt'] - - if len(prompt_list) > 10: - raise ValueError("prompt messages must be less than 10") + raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/app_service.py b/api/services/app_service.py new file mode 100644 index 00000000000000..6011b6a667935d --- /dev/null +++ b/api/services/app_service.py @@ -0,0 +1,338 @@ +import json +import logging +from datetime import datetime +from typing import cast + +import yaml +from flask_sqlalchemy.pagination import Pagination + +from constants.model_template import default_app_templates +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, AppModelConfig +from services.workflow_service import WorkflowService + + +class AppService: + def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination: + """ + Get app list with pagination + :param tenant_id: tenant id + :param args: request args + :return: + """ + filters = [ + App.tenant_id == tenant_id, + App.is_universal == False + ] + + if args['mode'] == 'workflow': + filters.append(App.mode.in_([AppMode.WORKFLOW.value, 
AppMode.COMPLETION.value])) + elif args['mode'] == 'chat': + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args['mode'] == 'agent-chat': + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args['mode'] == 'channel': + filters.append(App.mode == AppMode.CHANNEL.value) + + if 'name' in args and args['name']: + name = args['name'][:30] + filters.append(App.name.ilike(f'%{name}%')) + + app_models = db.paginate( + db.select(App).where(*filters).order_by(App.created_at.desc()), + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return app_models + + def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Create app + :param tenant_id: tenant id + :param args: request args + :param account: Account instance + """ + app_mode = AppMode.value_of(args['mode']) + app_template = default_app_templates[app_mode] + + # get model config + default_model_config = app_template.get('model_config') + if default_model_config and 'model' in default_model_config: + # get model provider + model_manager = ModelManager() + + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=account.current_tenant_id, + model_type=ModelType.LLM + ) + except (ProviderTokenNotInitError, LLMBadRequestError): + model_instance = None + except Exception as e: + logging.exception(e) + model_instance = None + + if model_instance: + if model_instance.model == default_model_config['model']['name']: + default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] + + default_model_config['model'] = json.dumps(default_model_dict) + + app = App(**app_template['app']) + app.name = args['name'] + app.description = args.get('description', '') + app.mode = args['mode'] + app.icon = args['icon'] + app.icon_background = args['icon_background'] + app.tenant_id = tenant_id + + db.session.add(app) + db.session.flush() + + if default_model_config: + app_model_config = AppModelConfig(**default_model_config) + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() + + app.app_model_config_id = app_model_config.id + + db.session.commit() + + app_was_created.send(app, account=account) + + return app + + def import_app(self, tenant_id: str, data: str, args: dict, account: Account) -> App: + """ + Import app + :param tenant_id: tenant id + :param data: import data + :param args: request args + :param account: Account instance + """ + try: + import_data = yaml.safe_load(data) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML format in data argument.") + + app_data = import_data.get('app') + model_config_data = import_data.get('model_config') + workflow = import_data.get('workflow') + + if not app_data: + raise ValueError("Missing app in data argument") + + app_mode = AppMode.value_of(app_data.get('mode')) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if not workflow: + raise ValueError("Missing workflow in data argument " + "when app mode is advanced-chat or workflow") + elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT]: + if not 
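An illustrative call to the new `AppService.get_paginate_apps`; the args keys mirror exactly what the method reads (`mode`, `name`, `page`, `limit`), and the tenant id is a placeholder.

```python
from services.app_service import AppService

pagination = AppService().get_paginate_apps(
    tenant_id="tenant-id",
    args={"mode": "chat", "name": "demo", "page": 1, "limit": 20}
)
for app in pagination.items:
    print(app.id, app.mode)
```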
model_config_data: + raise ValueError("Missing model_config in data argument " + "when app mode is chat or agent-chat") + else: + raise ValueError("Invalid app mode") + + app = App( + tenant_id=tenant_id, + mode=app_data.get('mode'), + name=args.get("name") if args.get("name") else app_data.get('name'), + description=args.get("description") if args.get("description") else app_data.get('description', ''), + icon=args.get("icon") if args.get("icon") else app_data.get('icon'), + icon_background=args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background'), + enable_site=True, + enable_api=True + ) + + db.session.add(app) + db.session.commit() + + app_was_created.send(app, account=account) + + if workflow: + # init draft workflow + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow.get('graph'), + features=workflow.get('features'), + account=account + ) + workflow_service.publish_workflow( + app_model=app, + account=account, + draft_workflow=draft_workflow + ) + + if model_config_data: + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + + app_model_config_was_updated.send( + app, + app_model_config=app_model_config + ) + + return app + + def export_app(self, app: App) -> str: + """ + Export app + :param app: App instance + :return: + """ + app_mode = AppMode.value_of(app.mode) + + export_data = { + "app": { + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background + } + } + + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app.workflow_id: + workflow = app.workflow + export_data['workflow'] = { + "graph": workflow.graph_dict, + "features": workflow.features_dict + } + else: + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app) + export_data['workflow'] = { + "graph": workflow.graph_dict, + "features": workflow.features_dict + } + else: + app_model_config = app.app_model_config + + export_data['model_config'] = app_model_config.to_dict() + + return yaml.dump(export_data) + + def update_app(self, app: App, args: dict) -> App: + """ + Update app + :param app: App instance + :param args: request args + :return: App instance + """ + app.name = args.get('name') + app.description = args.get('description', '') + app.icon = args.get('icon') + app.icon_background = args.get('icon_background') + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_name(self, app: App, name: str) -> App: + """ + Update app name + :param app: App instance + :param name: new name + :return: App instance + """ + app.name = name + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + """ + Update app icon + :param app: App instance + :param icon: new icon + :param icon_background: new icon_background + :return: App instance + """ + app.icon = icon + app.icon_background = icon_background + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_site_status(self, app: App, enable_site: bool) -> App: + """ + Update app site status + :param app: App instance + :param enable_site: enable site status + :return: App instance + """ + if 
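A hedged end-to-end sketch of `import_app` with an inline YAML payload: a chat-mode app requires `model_config`, whose exact shape is abbreviated here (it must satisfy `AppModelConfig.from_model_config_dict`); `account` stands in for an existing `models.account.Account` instance.

```python
from services.app_service import AppService

data = """
app:
  name: Imported Chat App
  mode: chat
  icon: "🤖"
  icon_background: "#FFEAD5"
model_config:
  model:
    provider: openai
    name: gpt-3.5-turbo
    mode: chat
    completion_params: {}
"""

app = AppService().import_app(
    tenant_id="tenant-id",  # placeholder
    data=data,
    args={},                # empty args fall back to the values in the YAML
    account=account         # assumed loaded Account
)
```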
enable_site == app.enable_site: + return app + + app.enable_site = enable_site + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_api_status(self, app: App, enable_api: bool) -> App: + """ + Update app api status + :param app: App instance + :param enable_api: enable api status + :return: App instance + """ + if enable_api == app.enable_api: + return app + + app.enable_api = enable_api + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def delete_app(self, app: App) -> None: + """ + Delete app + :param app: App instance + """ + db.session.delete(app) + db.session.commit() + + app_was_deleted.send(app) + + # todo async delete related data by event + # app_model_configs, site, api_tokens, installed_apps, recommended_apps BY app + # app_annotation_hit_histories, app_annotation_settings, app_dataset_joins BY app + # workflows, workflow_runs, workflow_node_executions, workflow_app_logs BY app + # conversations, pinned_conversations, messages BY app + # message_feedbacks, message_annotations, message_chains BY message + # message_agent_thoughts, message_files, saved_messages BY message diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a9fe65df6fbb25..d013a51c3e6507 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,6 +5,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from models.model import App, AppMode, AppModelConfig from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, @@ -20,7 +21,21 @@ class AudioService: @classmethod - def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[str] = None): + def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise ValueError("Speech to text is not enabled") + + features_dict = workflow.features_dict + if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + raise ValueError("Speech to text is not enabled") + else: + app_model_config: AppModelConfig = app_model.app_model_config + + if not app_model_config.speech_to_text_dict['enabled']: + raise ValueError("Speech to text is not enabled") + if file is None: raise NoAudioUploadedServiceError() @@ -37,7 +52,7 @@ def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[st model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: @@ -49,17 +64,42 @@ def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[st return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, tenant_id: str, text: str, voice: str, streaming: bool, end_user: Optional[str] = None): + def transcript_tts(cls, app_model: App, text: str, streaming: bool, + voice: Optional[str] = None, end_user: Optional[str] = None): + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise ValueError("TTS is not enabled") + + features_dict = workflow.features_dict + if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + 
raise ValueError("TTS is not enabled") + + voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + else: + text_to_speech_dict = app_model.app_model_config.text_to_speech_dict + + if not text_to_speech_dict.get('enabled'): + raise ValueError("TTS is not enabled") + + voice = text_to_speech_dict.get('voice') if voice is None else voice + model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() try: - return model_instance.invoke_tts(content_text=text.strip(), user=end_user, streaming=streaming, tenant_id=tenant_id, voice=voice) + return model_instance.invoke_tts( + content_text=text.strip(), + user=end_user, + streaming=streaming, + tenant_id=app_model.tenant_id, + voice=voice + ) except Exception as e: raise e diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef416b63..eb31ccbb3bf1ed 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -1,258 +1,71 @@ -import json from collections.abc import Generator from typing import Any, Union -from sqlalchemy import and_ - -from core.application_manager import ApplicationManager -from core.entities.application_entities import InvokeFrom -from core.file.message_file_parser import MessageFileParser -from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message -from services.app_model_config_service import AppModelConfigService -from services.errors.app import MoreLikeThisDisabledError -from services.errors.app_model_config import AppModelConfigBrokenError -from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import MessageNotExistsError +from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator +from core.app.apps.chat.app_generator import ChatAppGenerator +from core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from models.model import Account, App, AppMode, EndUser class CompletionService: @classmethod def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, - invoke_from: InvokeFrom, streaming: bool = True, - is_model_config_override: bool = False) -> Union[dict, Generator]: - # is streaming mode - inputs = args['inputs'] - query = args['query'] - files = args['files'] if 'files' in args and args['files'] else [] - auto_generate_name = args['auto_generate_name'] \ - if 'auto_generate_name' in args else True - - if app_model.mode != 'completion': - if not query: - raise ValueError('query is required') - - if query: - if not isinstance(query, str): - raise ValueError('query must be a string') - - query = query.replace('\x00', '') - - conversation_id = args['conversation_id'] if 'conversation_id' in args else None - - conversation = None - if conversation_id: - conversation_filter = [ - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id, - Conversation.status == 'normal' - ] - - if isinstance(user, Account): - conversation_filter.append(Conversation.from_account_id == user.id) - else: - conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) - - conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() - - if 
not conversation: - raise ConversationNotExistsError() - - if conversation.status != 'normal': - raise ConversationCompletedError() - - if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() - - if not app_model_config: - raise AppModelConfigBrokenError() - else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) - - app_model_config = AppModelConfig( - id=conversation.app_model_config_id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) - - if is_model_config_override: - # build new app model config - if 'model' not in args['model_config']: - raise ValueError('model_config.model is required') - - if 'completion_params' not in args['model_config']['model']: - raise ValueError('model_config.model.completion_params is required') - - completion_params = AppModelConfigService.validate_model_completion_params( - cp=args['model_config']['model']['completion_params'], - model_name=app_model_config.model_dict["name"] - ) - - app_model_config_model = app_model_config.model_dict - app_model_config_model['completion_params'] = completion_params - app_model_config.retriever_resource = json.dumps({'enabled': True}) - - app_model_config = app_model_config.copy() - app_model_config.model = json.dumps(app_model_config_model) + invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]: + """ + App Completion + :param app_model: app model + :param user: user + :param args: args + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + if app_model.mode == AppMode.COMPLETION.value: + return CompletionAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) + elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + return AgentChatAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) + elif app_model.mode == AppMode.CHAT.value: + return ChatAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming + ) else: - if app_model.app_model_config_id is None: - raise AppModelConfigBrokenError() - - app_model_config = app_model.app_model_config - - if not app_model_config: - raise AppModelConfigBrokenError() - - if is_model_config_override: - if not isinstance(user, Account): - raise Exception("Only account can override model config") - - # validate config - model_config = AppModelConfigService.validate_configuration( - tenant_id=app_model.tenant_id, - account=user, - config=args['model_config'], - app_mode=app_model.mode - ) - - app_model_config = AppModelConfig( - id=app_model_config.id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(model_config) - - # clean input by app_model_config form rules - inputs = cls.get_cleaned_inputs(inputs, app_model_config) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - app_model_config, - user - ) - - application_manager = ApplicationManager() - return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - 
app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=is_model_config_override, - user=user, - invoke_from=invoke_from, - inputs=inputs, - query=query, - files=file_objs, - conversation=conversation, - stream=streaming, - extras={ - "auto_generate_conversation_name": auto_generate_name - } - ) + raise ValueError('Invalid app mode') @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ -> Union[dict, Generator]: - if not user: - raise ValueError('user cannot be None') - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() - - if not message: - raise MessageNotExistsError() - - current_app_model_config = app_model.app_model_config - more_like_this = current_app_model_config.more_like_this_dict - - if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: - raise MoreLikeThisDisabledError() - - app_model_config = message.app_model_config - model_dict = app_model_config.model_dict - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - app_model_config.model = json.dumps(model_dict) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) - - application_manager = ApplicationManager() - return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, + """ + Generate more like this + :param app_model: app model + :param user: user + :param message_id: message id + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + return CompletionAppGenerator().generate_more_like_this( + app_model=app_model, + message_id=message_id, user=user, invoke_from=invoke_from, - inputs=message.inputs, - query=message.query, - files=file_objs, - conversation=None, - stream=streaming, - extras={ - "auto_generate_conversation_name": False - } + stream=streaming ) - - @classmethod - def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - # Filter input variables from form configuration, handle required fields, default values, and option values - input_form_config = app_model_config.user_input_form_list - for config in input_form_config: - input_config = list(config.values())[0] - variable = input_config["variable"] - - input_type = list(config.keys())[0] - - if variable not in user_inputs or not user_inputs[variable]: - if input_type == "external_data_tool": - continue - if "required" in input_config and input_config["required"]: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = input_config["default"] if "default" in input_config else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be 
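A hedged sketch of the slimmed-down `CompletionService`, which now only routes to the per-mode app generators; `app_model` and `user` are assumed to be loaded ORM objects.

```python
from core.app.entities.app_invoke_entities import InvokeFrom
from services.completion_service import CompletionService

response = CompletionService.completion(
    app_model=app_model,      # a chat-mode App
    user=user,                # Account or EndUser
    args={"inputs": {}, "query": "Hi there", "files": []},
    invoke_from=InvokeFrom.SERVICE_API,
    streaming=True            # a Generator is returned when streaming
)
```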
a string") - - if input_type == "select": - options = input_config["options"] if "options" in input_config else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if 'max_length' in input_config: - max_length = input_config['max_length'] - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ac3df380b265a7..1a0213799e619a 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account diff --git a/api/services/message_service.py b/api/services/message_service.py index ad2ff60f6b83c7..20918a8781bed3 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType diff --git a/api/services/workflow/__init__.py b/api/services/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py new file mode 100644 index 00000000000000..953c5c5a3cdb73 --- /dev/null +++ b/api/services/workflow/workflow_converter.py @@ -0,0 +1,608 @@ +import json +from typing import Optional + +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, + EasyUIBasedAppConfig, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, +) +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.workflow.entities.node_entities import NodeType +from events.app_event import app_was_created +from extensions.ext_database import db +from models.account import Account +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import App, AppMode, AppModelConfig +from models.workflow import Workflow, WorkflowType + + +class WorkflowConverter: + """ + App Convert to Workflow Mode + """ + + def convert_to_workflow(self, app_model: App, account: Account) -> App: + """ + Convert app to workflow + + - basic mode of chatbot app + + - expert mode of chatbot app + + - completion app + + :param app_model: App instance + :param account: Account + :return: new App instance + """ + # convert app model config + workflow = 
self.convert_app_model_config_to_workflow( + app_model=app_model, + app_model_config=app_model.app_model_config, + account_id=account.id + ) + + # create new app + new_app = App() + new_app.tenant_id = app_model.tenant_id + new_app.name = app_model.name + '(workflow)' + new_app.mode = AppMode.ADVANCED_CHAT.value \ + if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.workflow_id = workflow.id + new_app.icon = app_model.icon + new_app.icon_background = app_model.icon_background + new_app.enable_site = app_model.enable_site + new_app.enable_api = app_model.enable_api + new_app.api_rpm = app_model.api_rpm + new_app.api_rph = app_model.api_rph + new_app.is_demo = False + new_app.is_public = app_model.is_public + db.session.add(new_app) + db.session.commit() + + app_was_created.send(new_app, account=account) + + return new_app + + def convert_app_model_config_to_workflow(self, app_model: App, + app_model_config: AppModelConfig, + account_id: str) -> Workflow: + """ + Convert app model config to workflow mode + :param app_model: App instance + :param app_model_config: AppModelConfig instance + :param account_id: Account ID + :return: + """ + # get new app mode + new_app_mode = self._get_new_app_mode(app_model) + + # convert app model config + app_config = self._convert_to_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + + # init workflow graph + graph = { + "nodes": [], + "edges": [] + } + + # Convert list: + # - variables -> start + # - model_config -> llm + # - prompt_template -> llm + # - file_upload -> llm + # - external_data_variables -> http-request + # - dataset -> knowledge-retrieval + # - show_retrieve_source -> knowledge-retrieval + + # convert to start node + start_node = self._convert_to_start_node( + variables=app_config.variables + ) + + graph['nodes'].append(start_node) + + # convert to http request node + if app_config.external_data_variables: + http_request_nodes = self._convert_to_http_request_node( + app_model=app_model, + variables=app_config.variables, + external_data_variables=app_config.external_data_variables + ) + + for http_request_node in http_request_nodes: + graph = self._append_node(graph, http_request_node) + + # convert to knowledge retrieval node + if app_config.dataset: + knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=app_config.dataset + ) + + if knowledge_retrieval_node: + graph = self._append_node(graph, knowledge_retrieval_node) + + # convert to llm node + llm_node = self._convert_to_llm_node( + new_app_mode=new_app_mode, + graph=graph, + model_config=app_config.model, + prompt_template=app_config.prompt_template, + file_upload=app_config.additional_features.file_upload + ) + + graph = self._append_node(graph, llm_node) + + if new_app_mode == AppMode.WORKFLOW: + # convert to end node by app mode + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + else: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + + app_model_config_dict = app_config.app_model_config_dict + + # features + if new_app_mode == AppMode.ADVANCED_CHAT: + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + 
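A minimal conversion sketch: the original app is left untouched and a new advanced-chat or workflow app is returned, per `convert_to_workflow` above; `app_model` and `account` are assumed to be loaded ORM objects.

```python
from services.workflow.workflow_converter import WorkflowConverter

new_app = WorkflowConverter().convert_to_workflow(app_model=app_model, account=account)
# 'advanced-chat' for chat apps, 'workflow' for completion apps
print(new_app.mode)
```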
"text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + else: + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + + # create workflow record + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(new_app_mode).value, + version='draft', + graph=json.dumps(graph), + features=json.dumps(features), + created_by=account_id, + created_at=app_model_config.created_at + ) + + db.session.add(workflow) + db.session.commit() + + return workflow + + def _convert_to_app_config(self, app_model: App, + app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_model.mode = AppMode.AGENT_CHAT.value + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + elif app_mode == AppMode.CHAT: + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + elif app_mode == AppMode.COMPLETION: + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config + ) + else: + raise ValueError("Invalid app mode") + + return app_config + + def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: + """ + Convert to Start Node + :param variables: list of variables + :return: + """ + return { + "id": "start", + "position": None, + "data": { + "title": "START", + "type": NodeType.START.value, + "variables": [jsonable_encoder(v) for v in variables] + } + } + + def _convert_to_http_request_node(self, app_model: App, + variables: list[VariableEntity], + external_data_variables: list[ExternalDataVariableEntity]) -> list[dict]: + """ + Convert API Based Extension to HTTP Request Node + :param app_model: App instance + :param variables: list of variables + :param external_data_variables: list of external data variables + :return: + """ + index = 1 + nodes = [] + tenant_id = app_model.tenant_id + for external_data_variable in external_data_variables: + tool_type = external_data_variable.type + if tool_type != "api": + continue + + tool_variable = external_data_variable.variable + tool_config = external_data_variable.config + + # get params from config + api_based_extension_id = tool_config.get("api_based_extension_id") + + # get api_based_extension + api_based_extension = self._get_api_based_extension( + tenant_id=tenant_id, + api_based_extension_id=api_based_extension_id + ) + + if not api_based_extension: + raise ValueError("[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid" + .format(tool_variable)) + + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=tenant_id, + token=api_based_extension.api_key + ) + + http_request_variables = [] + inputs = {} + for v in variables: + http_request_variables.append({ + "variable": v.variable, + "value_selector": ["start", v.variable] + }) + + inputs[v.variable] = '{{' + v.variable + '}}' + + if app_model.mode == AppMode.CHAT.value: + http_request_variables.append({ + "variable": 
"_query", + "value_selector": ["start", "sys.query"] + }) + + request_body = { + 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + 'params': { + 'app_id': app_model.id, + 'tool_variable': tool_variable, + 'inputs': inputs, + 'query': '{{_query}}' if app_model.mode == AppMode.CHAT.value else '' + } + } + + request_body_json = json.dumps(request_body) + request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}') + + http_request_node = { + "id": f"http-request-{index}", + "position": None, + "data": { + "title": f"HTTP REQUEST {api_based_extension.name}", + "type": NodeType.HTTP_REQUEST.value, + "variables": http_request_variables, + "method": "post", + "url": api_based_extension.api_endpoint, + "authorization": { + "type": "api-key", + "config": { + "type": "bearer", + "api_key": api_key + } + }, + "headers": "", + "params": "", + "body": { + "type": "json", + "data": request_body_json + } + } + } + + nodes.append(http_request_node) + + # append code node for response body parsing + code_node = { + "id": f"code-{index}", + "position": None, + "data": { + "title": f"Parse {api_based_extension.name} Response", + "type": NodeType.CODE.value, + "variables": [{ + "variable": "response_json", + "value_selector": [http_request_node['id'], "body"] + }], + "code_language": "python3", + "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" + "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", + "outputs": [ + { + "variable": "result", + "variable_type": "string" + } + ] + } + } + + nodes.append(code_node) + index += 1 + + return nodes + + def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \ + -> Optional[dict]: + """ + Convert datasets to Knowledge Retrieval Node + :param new_app_mode: new app mode + :param dataset_config: dataset + :return: + """ + retrieve_config = dataset_config.retrieve_config + if new_app_mode == AppMode.CHAT: + query_variable_selector = ["start", "sys.query"] + elif retrieve_config.query_variable: + # fetch query variable + query_variable_selector = ["start", retrieve_config.query_variable] + else: + return None + + return { + "id": "knowledge-retrieval", + "position": None, + "data": { + "title": "KNOWLEDGE RETRIEVAL", + "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "query_variable_selector": query_variable_selector, + "dataset_ids": dataset_config.dataset_ids, + "retrieval_mode": retrieve_config.retrieve_strategy.value, + "multiple_retrieval_config": { + "top_k": retrieve_config.top_k, + "score_threshold": retrieve_config.score_threshold, + "reranking_model": retrieve_config.reranking_model + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + else None, + } + } + + def _convert_to_llm_node(self, new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileUploadEntity] = None) -> dict: + """ + Convert to LLM Node + :param new_app_mode: new app mode + :param graph: graph + :param model_config: model config + :param prompt_template: prompt template + :param file_upload: file upload config (optional) + """ + # fetch start and knowledge retrieval node + start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) + knowledge_retrieval_node = next(filter( + lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, + graph['nodes'] + ), None) + + role_prefix = 
None + + # Chat Model + if model_config.mode == LLMMode.CHAT.value: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = [ + { + "role": 'user', + "text": prompt_template_config['prompt_template'].template + } + ] + else: + advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template + prompts = [{ + "role": m.role.value, + "text": m.text + } for m in advanced_chat_prompt_template.messages] \ + if advanced_chat_prompt_template else [] + # Completion Model + else: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = { + "text": prompt_template_config['prompt_template'].template + } + + prompt_rules = prompt_template_config['prompt_rules'] + role_prefix = { + "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + } + else: + advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template + prompts = { + "text": advanced_completion_prompt_template.prompt, + } if advanced_completion_prompt_template else {"text": ""} + + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: + role_prefix = { + "user": advanced_completion_prompt_template.role_prefix.user, + "assistant": advanced_completion_prompt_template.role_prefix.assistant + } + + memory = None + if new_app_mode == AppMode.CHAT: + memory = { + "role_prefix": role_prefix, + "window": { + "enabled": False + } + } + + return { + "id": "llm", + "position": None, + "data": { + "title": "LLM", + "type": NodeType.LLM.value, + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": {**model_config.parameters, "stop": model_config.stop} + }, + "variables": [{ + "variable": v['variable'], + "value_selector": ["start", v['variable']] + } for v in start_node['data']['variables']], + "prompts": prompts, + "memory": memory, + "context": { + "enabled": knowledge_retrieval_node is not None, + "variable_selector": ["knowledge-retrieval", "result"] + if knowledge_retrieval_node is not None else None + }, + "vision": { + "enabled": file_upload is not None, + "variable_selector": ["start", "sys.files"] if file_upload is not None else None, + "configs": { + "detail": file_upload.image_config['detail'] + } if file_upload is not None else None + } + } + } + + def _convert_to_end_node(self) -> dict: + """ + Convert to End Node + :return: + """ + # for original completion app + return { + "id": "end", + "position": None, + "data": { + "title": "END", + "type": NodeType.END.value, + "outputs": [{ + "variable": "result", + "value_selector": ["llm", "text"] + }] + } + } + + def _convert_to_answer_node(self) -> dict: + """ + Convert to Answer Node + 
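A hedged illustration of the graph this converter produces for a plain chat app with no external data tools or datasets: start -> llm -> answer, with edges appended in insertion order by `_append_node`; node payloads are trimmed, and the type strings are assumed from `NodeType`.

```python
graph = {
    "nodes": [
        {"id": "start", "data": {"type": "start"}},
        {"id": "llm", "data": {"type": "llm"}},
        {"id": "answer", "data": {"type": "answer"}}
    ],
    "edges": [
        {"id": "start-llm", "source": "start", "target": "llm"},
        {"id": "llm-answer", "source": "llm", "target": "answer"}
    ]
}
```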
:return: + """ + # for original chat app + return { + "id": "answer", + "position": None, + "data": { + "title": "ANSWER", + "type": NodeType.ANSWER.value, + "variables": { + "variable": "text", + "value_selector": ["llm", "text"] + }, + "answer": "{{text}}" + } + } + + def _create_edge(self, source: str, target: str) -> dict: + """ + Create Edge + :param source: source node id + :param target: target node id + :return: + """ + return { + "id": f"{source}-{target}", + "source": source, + "target": target + } + + def _append_node(self, graph: dict, node: dict) -> dict: + """ + Append Node to Graph + + :param graph: Graph, include: nodes, edges + :param node: Node to append + :return: + """ + previous_node = graph['nodes'][-1] + graph['nodes'].append(node) + graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + return graph + + def _get_new_app_mode(self, app_model: App) -> AppMode: + """ + Get new app mode + :param app_model: App instance + :return: AppMode + """ + if app_model.mode == AppMode.COMPLETION.value: + return AppMode.WORKFLOW + else: + return AppMode.ADVANCED_CHAT + + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + """ + Get API Based Extension + :param tenant_id: tenant id + :param api_based_extension_id: api based extension id + :return: + """ + return db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py new file mode 100644 index 00000000000000..047678837509e2 --- /dev/null +++ b/api/services/workflow_app_service.py @@ -0,0 +1,62 @@ +from flask_sqlalchemy.pagination import Pagination +from sqlalchemy import and_, or_ + +from extensions.ext_database import db +from models import CreatedByRole +from models.model import App, EndUser +from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus + + +class WorkflowAppService: + + def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: + """ + Get paginate workflow app logs + :param app: app model + :param args: request args + :return: + """ + query = ( + db.select(WorkflowAppLog) + .where( + WorkflowAppLog.tenant_id == app_model.tenant_id, + WorkflowAppLog.app_id == app_model.id + ) + ) + + status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None + if args['keyword'] or status: + query = query.join( + WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id + ) + + if args['keyword']: + keyword_val = f"%{args['keyword'][:30]}%" + keyword_conditions = [ + WorkflowRun.inputs.ilike(keyword_val), + WorkflowRun.outputs.ilike(keyword_val), + # filter keyword by end user session id if created by end user role + and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val)) + ] + + query = query.outerjoin( + EndUser, + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + ).filter(or_(*keyword_conditions)) + + if status: + # join with workflow_run and filter by status + query = query.filter( + WorkflowRun.status == status.value + ) + + query = query.order_by(WorkflowAppLog.created_at.desc()) + + pagination = db.paginate( + query, + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return pagination diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py new file mode 100644 
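A sketch of querying workflow app logs with the new service; `status` is assumed to accept `WorkflowRunStatus` values such as `'succeeded'`, and `keyword` matches run inputs, outputs, and end-user session ids per the filter above.

```python
from services.workflow_app_service import WorkflowAppService

logs = WorkflowAppService().get_paginate_workflow_app_logs(
    app_model=app_model,  # a workflow-mode App, assumed loaded
    args={"keyword": "order", "status": "succeeded", "page": 1, "limit": 20}
)
for log in logs.items:
    print(log.workflow_run_id, log.created_from)
```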
index 00000000000000..1d3f93f2247d19
--- /dev/null
+++ b/api/services/workflow_run_service.py
@@ -0,0 +1,93 @@
+from extensions.ext_database import db
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.model import App
+from models.workflow import (
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionTriggeredFrom,
+    WorkflowRun,
+    WorkflowRunTriggeredFrom,
+)
+
+
+class WorkflowRunService:
+    def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
+        """
+        Get debug workflow run list
+        Only return triggered_from == debugging
+
+        :param app_model: app model
+        :param args: request args
+        """
+        limit = int(args.get('limit', 20))
+
+        base_query = db.session.query(WorkflowRun).filter(
+            WorkflowRun.tenant_id == app_model.tenant_id,
+            WorkflowRun.app_id == app_model.id,
+            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
+        )
+
+        if args.get('last_id'):
+            last_workflow_run = base_query.filter(
+                WorkflowRun.id == args.get('last_id'),
+            ).first()
+
+            if not last_workflow_run:
+                raise ValueError('Last workflow run does not exist')
+
+            workflow_runs = base_query.filter(
+                WorkflowRun.created_at < last_workflow_run.created_at,
+                WorkflowRun.id != last_workflow_run.id
+            ).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+        else:
+            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+
+        has_more = False
+        if len(workflow_runs) == limit:
+            # runs are sorted newest first, so the last item is the oldest on this page
+            current_page_first_workflow_run = workflow_runs[-1]
+            rest_count = base_query.filter(
+                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
+                WorkflowRun.id != current_page_first_workflow_run.id
+            ).count()
+
+            if rest_count > 0:
+                has_more = True
+
+        return InfiniteScrollPagination(
+            data=workflow_runs,
+            limit=limit,
+            has_more=has_more
+        )
+
+    def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
+        """
+        Get workflow run detail
+
+        :param app_model: app model
+        :param run_id: workflow run id
+        """
+        workflow_run = db.session.query(WorkflowRun).filter(
+            WorkflowRun.tenant_id == app_model.tenant_id,
+            WorkflowRun.app_id == app_model.id,
+            WorkflowRun.id == run_id,
+        ).first()
+
+        return workflow_run
+
+    def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
+        """
+        Get workflow run node execution list
+        """
+        workflow_run = self.get_workflow_run(app_model, run_id)
+
+        if not workflow_run:
+            return []
+
+        node_executions = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.tenant_id == app_model.tenant_id,
+            WorkflowNodeExecution.app_id == app_model.id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == run_id,
+        ).order_by(WorkflowNodeExecution.index.desc()).all()
+
+        return node_executions
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
new file mode 100644
index 00000000000000..55f2526fbfc827
--- /dev/null
+++ b/api/services/workflow_service.py
@@ -0,0 +1,358 @@
+import json
+import time
+from collections.abc import Generator
+from datetime import datetime
+from typing import Optional, Union
+
+from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
+from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from
core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeType +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, EndUser +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowType, +) +from services.workflow.workflow_converter import WorkflowConverter + + +class WorkflowService: + """ + Workflow Service + """ + + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() + + # return draft workflow + return workflow + + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get published workflow + """ + + if not app_model.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id + ).first() + + return workflow + + def sync_draft_workflow(self, app_model: App, + graph: dict, + features: dict, + account: Account) -> Workflow: + """ + Sync draft workflow + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + # validate features structure + self.validate_features_structure( + app_model=app_model, + features=features + ) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(app_model.mode).value, + version='draft', + graph=json.dumps(graph), + features=json.dumps(features), + created_by=account.id + ) + db.session.add(workflow) + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = datetime.utcnow() + + # commit db session changes + db.session.commit() + + # return draft workflow + return workflow + + def publish_workflow(self, app_model: App, + account: Account, + draft_workflow: Optional[Workflow] = None) -> Workflow: + """ + Publish workflow from draft + + :param app_model: App instance + :param account: Account instance + :param draft_workflow: Workflow instance + """ + if not draft_workflow: + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('No valid workflow found.') + + # TODO check if the workflow structure is valid + + # create new workflow + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=draft_workflow.type, + version=str(datetime.utcnow()), + graph=draft_workflow.graph, + created_by=account.id + ) + + # commit db session changes + db.session.add(workflow) + db.session.commit() + + app_model.workflow_id = workflow.id + db.session.commit() + + # 
TODO update app related datasets + + # return new workflow + return workflow + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configs + """ + # return default block config + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_configs() + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. + :return: + """ + node_type = NodeType.value_of(node_type) + + # return default block config + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_config(node_type, filters) + + def run_advanced_chat_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + """ + Run advanced chatbot draft workflow + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = AdvancedChatAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response + + def run_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = WorkflowAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response + + def stop_workflow_task(self, task_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom) -> None: + """ + Stop workflow task + """ + AppQueueManager.set_stop_flag(task_id, invoke_from, user.id) + + def run_draft_workflow_node(self, app_model: App, + node_id: str, + user_inputs: dict, + account: Account) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow node + workflow_engine_manager = WorkflowEngineManager() + start_at = time.perf_counter() + + try: + node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ) + except WorkflowNodeRunFailedError as e: + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=e.node_id, + node_type=e.node_type.value, + title=e.node_title, + status=WorkflowNodeExecutionStatus.FAILED.value, + error=e.error, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + if 
node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
+                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
+                outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
+                execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
+                                    if node_run_result.metadata else None),
+                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )
+        else:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                status=node_run_result.status.value,
+                error=node_run_result.error,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
+
+    def convert_to_workflow(self, app_model: App, account: Account) -> App:
+        """
+        Convert basic mode of chatbot app (expert mode) to workflow app,
+        and completion app to workflow app.
+
+        :param app_model: App instance
+        :param account: Account instance
+        :return:
+        """
+        # chatbot convert to workflow mode
+        workflow_converter = WorkflowConverter()
+
+        if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
+            raise ValueError(f'Current app mode: {app_model.mode} cannot be converted to a workflow.')
+
+        # convert to workflow
+        new_app = workflow_converter.convert_to_workflow(
+            app_model=app_model,
+            account=account
+        )
+
+        return new_app
+
+    def validate_features_structure(self, app_model: App, features: dict) -> dict:
+        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+            return AdvancedChatAppConfigManager.config_validate(
+                tenant_id=app_model.tenant_id,
+                config=features,
+                only_structure_validate=True
+            )
+        elif app_model.mode == AppMode.WORKFLOW.value:
+            return WorkflowAppConfigManager.config_validate(
+                tenant_id=app_model.tenant_id,
+                config=features,
+                only_structure_validate=True
+            )
+        else:
+            raise ValueError(f"Invalid app mode: {app_model.mode}")
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 04abacf73d2c17..dd1baa79d4ec9d 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -66,4 +66,8 @@ JINA_API_KEY=
 OLLAMA_BASE_URL=
 
 # Mock Switch
-MOCK_SWITCH=false
\ No newline at end of file
+MOCK_SWITCH=false
+
+# CODE EXECUTION CONFIGURATION
+CODE_EXECUTION_ENDPOINT=
+CODE_EXECUTION_API_KEY=
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/__init__.py b/api/tests/integration_tests/workflow/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/workflow/nodes/__init__.py b/api/tests/integration_tests/workflow/nodes/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
new file mode 100644
index 00000000000000..2eb987181fa93a
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -0,0 +1,31 @@
+import os
+import pytest
+
+from typing import Literal
+from _pytest.monkeypatch import MonkeyPatch
+from core.helper.code_executor.code_executor import CodeExecutor
+
+MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+
+class MockedCodeExecutor:
+    @classmethod
+    def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+        # invoke directly; only python3 and jinja2 are mocked here
+        if language == 'python3':
+            return {
+                "result": 3
+            }
+        elif language == 'jinja2':
+            return {
+                "result": "3"
+            }
+
+@pytest.fixture
+def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
+    if not MOCK:
+        yield
+        return
+
+    monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke)
+    yield
+    monkeypatch.undo()
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py
new file mode 100644
index 00000000000000..9cc43031f3a05c
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py
@@ -0,0 +1,85 @@
+import os
+import pytest
+import requests.api as requests
+import httpx._api as httpx
+from requests import Response as RequestsResponse
+from httpx import Request as HttpxRequest
+from yarl import URL
+
+from typing import Literal
+from _pytest.monkeypatch import MonkeyPatch
+from json import dumps
+
+MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+
+class MockedHttp:
+    @staticmethod
+    def requests_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str,
+                         **kwargs) -> RequestsResponse:
+        """
+        Mocked requests.request
+        """
+        response = RequestsResponse()
+        response.url = str(URL(url) % kwargs.get('params', {}))
+        response.headers = kwargs.get('headers', {})
+
+        if url == 'http://404.com':
+            response.status_code = 404
+            response._content = b'Not Found'
+            return response
+
+        # get data, files; files take precedence over data when both are provided
+        data = kwargs.get('data', None)
+        files = kwargs.get('files', None)
+
+        if files is not None:
+            resp = dumps(files).encode('utf-8')
+        elif data is not None:
+            resp = dumps(data).encode('utf-8')
+        else:
+            resp = b'OK'
+
+        response.status_code = 200
+        response._content = resp
+        return response
+
+    @staticmethod
+    def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
+                      url: str, **kwargs) -> httpx.Response:
+        """
+        Mocked httpx.request
+        """
+        response = httpx.Response(
+            status_code=200,
+            request=HttpxRequest(method, url)
+        )
+        response.headers = kwargs.get('headers', {})
+
+        if url == 'http://404.com':
+            response.status_code = 404
+            response._content = b'Not Found'
+            return response
+
+        # get data, files; files take precedence over data when both are provided
+        data = kwargs.get('data', None)
+        files = kwargs.get('files', None)
+
+        if files is not None:
+            resp = dumps(files).encode('utf-8')
+        elif data is not None:
+            resp = dumps(data).encode('utf-8')
+        else:
+            resp = b'OK'
+
+        response.status_code = 200
+        response._content = resp
+        return response
+
+@pytest.fixture
+def setup_http_mock(request, monkeypatch: MonkeyPatch):
+    if not MOCK:
+        yield
+        return
+
+    monkeypatch.setattr(requests,
"request", MockedHttp.requests_request) + monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) + yield + monkeypatch.undo() \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py new file mode 100644 index 00000000000000..0b7217b053b067 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -0,0 +1,266 @@ +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.code.code_node import CodeNode +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'outputs': { + 'result': { + 'type': 'number', + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] == 3 + assert result.error is None + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code_output_validator(setup_code_executor_mock): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + "outputs": { + "result": { + "type": "string", + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == 'result in output form must be a string' + +def test_execute_code_output_validator_depth(): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + ''' + # trim first 4 spaces at the beginning of each line + 
code = '\n'.join([line[4:] for line in code.split('\n')])
+    node = CodeNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=InvokeFrom.WEB_APP,
+        config={
+            'id': '1',
+            'data': {
+                "outputs": {
+                    "string_validator": {
+                        "type": "string",
+                    },
+                    "number_validator": {
+                        "type": "number",
+                    },
+                    "number_array_validator": {
+                        "type": "array[number]",
+                    },
+                    "string_array_validator": {
+                        "type": "array[string]",
+                    },
+                    "object_validator": {
+                        "type": "object",
+                        "children": {
+                            "result": {
+                                "type": "number",
+                            },
+                            "depth": {
+                                "type": "object",
+                                "children": {
+                                    "depth": {
+                                        "type": "object",
+                                        "children": {
+                                            "depth": {
+                                                "type": "number",
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    },
+                },
+                'title': '123',
+                'variables': [
+                    {
+                        'variable': 'args1',
+                        'value_selector': ['1', '123', 'args1'],
+                    },
+                    {
+                        'variable': 'args2',
+                        'value_selector': ['1', '123', 'args2']
+                    }
+                ],
+                'answer': '123',
+                'code_language': 'python3',
+                'code': code
+            }
+        }
+    )
+
+    # construct result
+    result = {
+        "number_validator": 1,
+        "string_validator": "1",
+        "number_array_validator": [1, 2, 3, 3.333],
+        "string_array_validator": ["1", "2", "3"],
+        "object_validator": {
+            "result": 1,
+            "depth": {
+                "depth": {
+                    "depth": 1
+                }
+            }
+        }
+    }
+
+    # validate
+    node._transform_result(result, node.node_data.outputs)
+
+    # construct result
+    result = {
+        "number_validator": "1",
+        "string_validator": 1,
+        "number_array_validator": ["1", "2", "3", "3.333"],
+        "string_array_validator": [1, 2, 3],
+        "object_validator": {
+            "result": "1",
+            "depth": {
+                "depth": {
+                    "depth": "1"
+                }
+            }
+        }
+    }
+
+    # validate
+    with pytest.raises(ValueError):
+        node._transform_result(result, node.node_data.outputs)
+
+    # construct result
+    result = {
+        "number_validator": 1,
+        "string_validator": "1" * 2000,
+        "number_array_validator": [1, 2, 3, 3.333],
+        "string_array_validator": ["1", "2", "3"],
+        "object_validator": {
+            "result": 1,
+            "depth": {
+                "depth": {
+                    "depth": 1
+                }
+            }
+        }
+    }
+
+    # validate
+    with pytest.raises(ValueError):
+        node._transform_result(result, node.node_data.outputs)
+
+    # construct result
+    result = {
+        "number_validator": 1,
+        "string_validator": "1",
+        "number_array_validator": [1, 2, 3, 3.333] * 2000,
+        "string_array_validator": ["1", "2", "3"],
+        "object_validator": {
+            "result": 1,
+            "depth": {
+                "depth": {
+                    "depth": 1
+                }
+            }
+        }
+    }
+
+    # validate
+    with pytest.raises(ValueError):
+        node._transform_result(result, node.node_data.outputs)
+
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
new file mode 100644
index 00000000000000..8b94105b44527f
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -0,0 +1,313 @@
+import pytest
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+
+from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
+
+BASIC_NODE_DATA = {
+    'tenant_id': '1',
+    'app_id': '1',
+    'workflow_id': '1',
+    'user_id': '1',
+    'user_from': InvokeFrom.WEB_APP,
+}
+
+# construct variable pool
+pool = VariablePool(system_variables={}, user_inputs={})
+pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
+pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2)
+
+@pytest.mark.parametrize('setup_http_mock',
[['none']], indirect=True) +def test_get(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_no_auth(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'no-auth', + 'config': None, + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'X-Header: 123' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'custom', + 'api_key': 'Auth', + 'header': 'X-Auth', + }, + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'X-Header: 123' in data + assert 'X-Auth: Auth' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_template(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'get', + 'url': 'http://example.com/{{args1}}', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123\nX-Header2:{{args1}}', + 'params': 'A:b\nTemplate:{{args1}}', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'Template=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + assert 'X-Header2: 2' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_json(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'json', + 'data': '{"a": "{{args1}}"}' + }, + } + }, **BASIC_NODE_DATA) + + result = 
node.run(pool) + data = result.process_data.get('request', '') + + assert '{"a": "1"}' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_x_www_form_urlencoded(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'x-www-form-urlencoded', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'a=1&b=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_form_data(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'form-data', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'form-data; name="a"' in data + assert '1' in data + assert 'form-data; name="b"' in data + assert '2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_none_data(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'none', + 'data': '123123123' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + assert '123123123' not in data \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py new file mode 100644 index 00000000000000..999ebf77342601 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -0,0 +1,132 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration +from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import ModelProviderFactory +from 
core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.llm.llm_node import LLMNode +from extensions.ext_database import db +from models.provider import ProviderType +from models.workflow import WorkflowNodeExecutionStatus + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_execute_llm(setup_openai_mock): + node = LLMNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'llm', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'variables': [ + { + 'variable': 'weather', + 'value_selector': ['abc', 'output'], + }, + { + 'variable': 'query', + 'value_selector': ['sys', 'query'] + } + ], + 'prompt_template': [ + { + 'role': 'system', + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.' + }, + { + 'role': 'user', + 'text': '{{query}}' + } + ], + 'memory': { + 'window': { + 'enabled': True, + 'size': 2 + } + }, + 'context': { + 'enabled': False + }, + 'vision': { + 'enabled': False + } + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather today?', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION: 'abababa' + }, user_inputs={}) + pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + + credentials = { + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + + provider_instance = ModelProviderFactory().get_provider_instance('openai') + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ) + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_config = ModelConfigWithCredentialsEntity( + model='gpt-3.5-turbo', + provider='openai', + mode='chat', + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), + provider_model_bundle=provider_model_bundle + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['text'] is not None + assert result.outputs['usage']['total_tokens'] > 0 diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py new file mode 100644 index 00000000000000..36cf0a070aa855 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -0,0 +1,46 @@ +import pytest + +from 
core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = '''{{args2}}''' + node = TemplateTransformNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.END_USER, + config={ + 'id': '1', + 'data': { + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'template': code, + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=3) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['output'] == '3' diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py new file mode 100644 index 00000000000000..66139563e29429 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -0,0 +1,45 @@ +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.tool.tool_node import ToolNode +from models.workflow import WorkflowNodeExecutionStatus + +def test_tool_invoke(): + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') + + node = ToolNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'title': 'a', + 'desc': 'a', + 'provider_id': 'maths', + 'provider_type': 'builtin', + 'provider_name': 'maths', + 'tool_name': 'eval_expression', + 'tool_label': 'eval_expression', + 'tool_configurations': {}, + 'tool_parameters': [ + { + 'variable': 'expression', + 'value_selector': ['1', '123', 'args1'], + 'variable_type': 'selector', + 'value': None + }, + ] + } + } + ) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert '2' in result.outputs['text'] + assert result.outputs['files'] == [] \ No newline at end of file diff --git a/api/tests/unit_tests/.gitignore b/api/tests/unit_tests/.gitignore new file mode 100644 index 00000000000000..426667562b31da --- /dev/null +++ b/api/tests/unit_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py new file mode 100644 index 00000000000000..afc9802cf1cbe7 --- /dev/null +++ b/api/tests/unit_tests/conftest.py @@ -0,0 +1,7 @@ +import os + +# Getting the absolute path of the current file's directory +ABS_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = 
os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) diff --git a/api/tests/unit_tests/core/__init__.py b/api/tests/unit_tests/core/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/prompt/__init__.py b/api/tests/unit_tests/core/prompt/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py new file mode 100644 index 00000000000000..5c08b9f168ad20 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -0,0 +1,211 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity +from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from models.model import Conversation + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." + prompt_template_config = CompletionModelPromptTemplate( + text=prompt_template + ) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix( + user="Human", + assistant="Assistant" + ), + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + inputs = { + "name": "John" + } + files = [] + context = "I am superman." + + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_completion_model_prompt_messages( + prompt_template=prompt_template_config, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 1 + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ + "#context#": context, + "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " + f"{prompt.content}" for prompt in history_prompt_messages]), + **inputs, + }) + + +def test__get_chat_model_prompt_messages(get_chat_model_args): + model_config_mock, memory_config, messages, inputs, context = get_chat_model_args + + files = [] + query = "Hi2." 
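+    # the fixture supplies 3 template messages (system, user, assistant); with the
+    # 2 mocked history messages and the appended query, the transform is expected
+    # to yield 6 prompt messages in total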
+ + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi1."), + AssistantPromptMessage(content="Hello1!") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=query, + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 6 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=messages[0].text + ).format({**inputs, "#context#": context}) + assert prompt_messages[5].content == query + + +def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): + model_config_mock, _, messages, inputs, context = get_chat_model_args + + files = [] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=None, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 3 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=messages[0].text + ).format({**inputs, "#context#": context}) + + +def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): + model_config_mock, _, messages, inputs, context = get_chat_model_args + + files = [ + FileObj( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + url="https://example.com/image1.jpg", + file_upload_entity=FileUploadEntity( + image_config={ + "detail": "high", + } + ) + ) + ] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template=messages, + inputs=inputs, + query=None, + files=files, + context=context, + memory_config=None, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 4 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=messages[0].text + ).format({**inputs, "#context#": context}) + assert isinstance(prompt_messages[3].content, list) + assert len(prompt_messages[3].content) == 2 + assert prompt_messages[3].content[1].data == files[0].url + + +@pytest.fixture +def get_chat_model_args(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + prompt_messages = [ + ChatModelMessage( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM + ), + ChatModelMessage( + text="Hi.", + role=PromptMessageRole.USER + ), + ChatModelMessage( + text="Hello!", + role=PromptMessageRole.ASSISTANT + ) + ] + + inputs = { + "name": "John" + } + + context = "I am superman." 
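+    # the tuple returned below is unpacked by the three chat-model tests above as
+    # (model_config_mock, memory_config, messages, inputs, context)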
+ + return model_config_mock, memory_config, prompt_messages, inputs, context diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py new file mode 100644 index 00000000000000..9796fc5558110f --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +from core.app.app_config.entities import ModelConfigEntity +from core.entities.provider_configuration import ProviderModelBundle +from core.model_runtime.entities.message_entities import UserPromptMessage +from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.prompt_transform import PromptTransform + + +def test__calculate_rest_token(): + model_schema_mock = MagicMock(spec=AIModelEntity) + parameter_rule_mock = MagicMock(spec=ParameterRule) + parameter_rule_mock.name = 'max_tokens' + model_schema_mock.parameter_rules = [ + parameter_rule_mock + ] + model_schema_mock.model_properties = { + ModelPropertyKey.CONTEXT_SIZE: 62 + } + + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + large_language_model_mock.get_num_tokens.return_value = 6 + + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.model = 'gpt-4' + model_config_mock.credentials = {} + model_config_mock.parameters = { + 'max_tokens': 50 + } + model_config_mock.model_schema = model_schema_mock + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + prompt_transform = PromptTransform() + + prompt_messages = [UserPromptMessage(content="Hello, how are you?")] + rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + + # Validate based on the mock configuration and expected logic + expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters['max_tokens'] + - large_language_model_mock.get_num_tokens.return_value) + assert rest_tokens == expected_rest_tokens + assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py new file mode 100644 index 00000000000000..be9fe8d004fa32 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -0,0 +1,248 @@ +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage +from core.prompt.simple_prompt_transform import SimplePromptTransform +from models.model import AppMode, Conversation + + +def test_get_common_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
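+    # "pcqm" in the test name presumably abbreviates the prompt parts under test:
+    # pre_prompt + context + query + memory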
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_baichuan_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_common_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_baichuan_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
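+    # same expectations as the common (openai) completion-model case above,
+    # exercised against the Baichuan-specific prompt rules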
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.WORKFLOW,
+        provider="baichuan",
+        model="Baichuan2-53B",
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=False,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+                                                           + pre_prompt + '\n'
+                                                           + prompt_rules['query_prompt'])
+    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+
+
+def test_get_common_chat_app_prompt_template_with_q():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = ""
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider="openai",
+        model="gpt-4",
+        pre_prompt=pre_prompt,
+        has_context=False,
+        query_in_prompt=True,
+        with_memory_prompt=False,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
+    assert prompt_template['special_variable_keys'] == ['#query#']
+
+
+def test_get_common_chat_app_prompt_template_with_cq():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = ""
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider="openai",
+        model="gpt-4",
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=False,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+                                                           + prompt_rules['query_prompt'])
+    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+
+
+def test_get_common_chat_app_prompt_template_with_p():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = "you are {{name}}"
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider="openai",
+        model="gpt-4",
+        pre_prompt=pre_prompt,
+        has_context=False,
+        query_in_prompt=False,
+        with_memory_prompt=False,
+    )
+    assert prompt_template['prompt_template'].template == pre_prompt + '\n'
+    assert prompt_template['custom_variable_keys'] == ['name']
+    assert prompt_template['special_variable_keys'] == []
+
+
+def test__get_chat_model_prompt_messages():
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-4'
+
+    memory_mock = MagicMock(spec=TokenBufferMemory)
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
+
+    prompt_transform = SimplePromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+
+    pre_prompt = "You are a helpful assistant {{name}}."
+    inputs = {
+        "name": "John"
+    }
+    context = "yes or no."
+    query = "How are you?"
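+    # expected layout: 1 generated system prompt + 2 mocked history messages
+    # + the user query = 4 prompt messages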
+
+    prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
+        app_mode=AppMode.CHAT,
+        pre_prompt=pre_prompt,
+        inputs=inputs,
+        query=query,
+        files=[],
+        context=context,
+        memory=memory_mock,
+        model_config=model_config_mock
+    )
+
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider=model_config_mock.provider,
+        model=model_config_mock.model,
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=False,
+        with_memory_prompt=False,
+    )
+
+    full_inputs = {**inputs, '#context#': context}
+    real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
+
+    assert len(prompt_messages) == 4
+    assert prompt_messages[0].content == real_system_prompt
+    assert prompt_messages[1].content == history_prompt_messages[0].content
+    assert prompt_messages[2].content == history_prompt_messages[1].content
+    assert prompt_messages[3].content == query
+
+
+def test__get_completion_model_prompt_messages():
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-3.5-turbo-instruct'
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
+    prompt_transform = SimplePromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    pre_prompt = "You are a helpful assistant {{name}}."
+    inputs = {
+        "name": "John"
+    }
+    context = "yes or no."
+    query = "How are you?"
+    prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
+        app_mode=AppMode.CHAT,
+        pre_prompt=pre_prompt,
+        inputs=inputs,
+        query=query,
+        files=[],
+        context=context,
+        memory=memory,
+        model_config=model_config_mock
+    )
+
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider=model_config_mock.provider,
+        model=model_config_mock.model,
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=True,
+    )
+
+    prompt_rules = prompt_template['prompt_rules']
+    # human_prefix labels the user turns and assistant_prefix the AI turns
+    full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
+        max_token_limit=2000,
+        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+    )}
+    real_prompt = prompt_template['prompt_template'].format(full_inputs)
+
+    assert len(prompt_messages) == 1
+    assert stops == prompt_rules.get('stops')
+    assert prompt_messages[0].content == real_prompt
diff --git a/api/tests/unit_tests/core/workflow/__init__.py b/api/tests/unit_tests/core/workflow/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/unit_tests/core/workflow/nodes/__init__.py b/api/tests/unit_tests/core/workflow/nodes/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
new file mode 100644
index 00000000000000..bad5d42a43e023
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
@@ -0,0 +1,56 @@
+from unittest.mock import MagicMock
+
+from core.workflow.entities.node_entities import SystemVariable
core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.base_node import UserFrom +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_answer(): + node = AnswerNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'answer', + 'data': { + 'title': '123', + 'type': 'answer', + 'variables': [ + { + 'value_selector': ['llm', 'text'], + 'variable': 'text' + }, + { + 'value_selector': ['start', 'weather'], + 'variable': 'weather' + }, + ], + 'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.' + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny') + pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + + +# TODO test files diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py new file mode 100644 index 00000000000000..7b402ad0a09193 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -0,0 +1,194 @@ +from unittest.mock import MagicMock + +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_if_else_result_true(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'and', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'start with', + 'variable_selector': ['start', 'start_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'end with', + 'variable_selector': ['start', 'end_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is', + 'variable_selector': ['start', 'is'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is not', + 'variable_selector': ['start', 'is_not'], + 'value': 'ab' + }, + { + 'comparison_operator': 'empty', + 'variable_selector': ['start', 'empty'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not empty', + 'variable_selector': ['start', 'not_empty'], + 'value': 'ab' + }, + { + 'comparison_operator': '=', + 'variable_selector': ['start', 'equals'], + 'value': '22' + },
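+ # Numeric comparison operators follow; each pool value appended in the setup below is chosen so that every condition evaluates to true under the 'and' operator.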
+ { + 'comparison_operator': '≠', + 'variable_selector': ['start', 'not_equals'], + 'value': '22' + }, + { + 'comparison_operator': '>', + 'variable_selector': ['start', 'greater_than'], + 'value': '22' + }, + { + 'comparison_operator': '<', + 'variable_selector': ['start', 'less_than'], + 'value': '22' + }, + { + 'comparison_operator': '≥', + 'variable_selector': ['start', 'greater_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': '≤', + 'variable_selector': ['start', 'less_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': 'null', + 'variable_selector': ['start', 'null'] + }, + { + 'comparison_operator': 'not null', + 'variable_selector': ['start', 'not_null'] + }, + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def']) + pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde') + pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde') + pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc') + pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab') + pool.append_variable(node_id='start', variable_key_list=['is'], value='ab') + pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab') + pool.append_variable(node_id='start', variable_key_list=['empty'], value='') + pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa') + pool.append_variable(node_id='start', variable_key_list=['equals'], value=22) + pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23) + pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23) + pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21) + pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22) + pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21) + pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is True + + +def test_execute_if_else_result_false(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'or', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + } + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def']) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == 
WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is False diff --git a/api/tests/unit_tests/services/__init__.py b/api/tests/unit_tests/services/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/services/workflow/__init__.py b/api/tests/unit_tests/services/workflow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py new file mode 100644 index 00000000000000..0ca8ae135ce8dc --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -0,0 +1,446 @@ +# test for api/services/workflow/workflow_converter.py +import json +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ + AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity +from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import AppMode +from services.workflow.workflow_converter import WorkflowConverter + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity( + variable="text-input", + label="text-input", + type=VariableEntity.Type.TEXT_INPUT + ), + VariableEntity( + variable="paragraph", + label="paragraph", + type=VariableEntity.Type.PARAGRAPH + ), + VariableEntity( + variable="select", + label="select", + type=VariableEntity.Type.SELECT + ) + ] + + +def test__convert_to_start_node(default_variables): + # act + result = WorkflowConverter()._convert_to_start_node(default_variables) + + # assert + assert isinstance(result["data"]["variables"][0]["type"], str) + assert result["data"]["variables"][0]["type"] == "text-input" + assert result["data"]["variables"][0]["variable"] == "text-input" + assert result["data"]["variables"][1]["variable"] == "paragraph" + assert result["data"]["variables"][2]["variable"] == "select" + + +def test__convert_to_http_request_node_for_chatbot(default_variables): + """ + Test convert to http request nodes for chatbot + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 4 # 
appended _query variable + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "{{_query}}" # for chatbot + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_http_request_node_for_workflow_app(default_variables): + """ + Test convert to http request nodes for workflow app + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 3 + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "" + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot(): + new_app_mode = AppMode.CHAT + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 
'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", "sys.query"] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app(): + new_app_mode = AppMode.WORKFLOW + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." 
+ ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ]) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": 
v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], list) + assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages) + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): + new_app_mode = AppMode.WORKFLOW + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], dict) + assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index afdabd078a4bbe..4f7965609b7089 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -50,6 +50,19 @@ services: AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' ports: - "8080:8080" + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + cap_add: + - SYS_ADMIN + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + GIN_MODE: 'release' + ports: + - "8194:8194" # Qdrant vector store. # uncomment to use qdrant as vector store. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d627bb38481173..f066582ac8e66e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -122,6 +122,9 @@ services: SENTRY_TRACES_SAMPLE_RATE: 1.0 # The sample rate for Sentry profiles. Default: `1.0` SENTRY_PROFILES_SAMPLE_RATE: 1.0 + # The sandbox service endpoint. + CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" + CODE_EXECUTION_API_KEY: dify-sandbox depends_on: - db - redis @@ -286,6 +289,19 @@ services: # ports: # - "8080:8080" + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + cap_add: + - SYS_ADMIN + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + GIN_MODE: release + ports: + - "8194:8194" + # Qdrant vector store. # uncomment to use qdrant as vector store.
# (if uncommented, you need to comment out the weaviate service above, diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index dfd6ec02095e35..4bc7fb77abe728 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -24,7 +24,22 @@ done if $api_modified; then echo "Running Ruff linter on api module" - ./dev/reformat + + # Python style checks rely on `ruff` being available on PATH + if ! command -v ruff > /dev/null 2>&1; then + echo "Installing Ruff ..." + pip install ruff + fi + + ruff check ./api || status=$? + + status=${status:-0} + + if [ $status -ne 0 ]; then + echo "Ruff linter failed on api module, exit code: $status" + echo "Please run 'dev/reformat' to fix the fixable linting errors." + exit 1 + fi fi if $web_modified; then