From 30eb7c20690acdb43b86f3e9859137c448b0e609 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Tue, 24 Oct 2023 17:35:21 +0800 Subject: [PATCH 01/57] add form schemas. --- {.vscode => api/.vscode}/launch.json | 2 +- api/controllers/console/__init__.py | 2 +- api/controllers/console/extension.py | 23 ++++++++ api/core/__init__.py | 1 + api/core/helper/auto_register.py | 17 ++++++ api/core/helper/extensible.py | 34 ++++++++++++ api/core/moderation/__init__.py | 4 ++ api/core/moderation/api_based/__init__.py | 0 api/core/moderation/api_based/api_based.py | 13 +++++ api/core/moderation/base.py | 46 ++++++++++++++++ api/core/moderation/cloud_service/__init__.py | 0 .../moderation/cloud_service/cloud_service.py | 5 ++ api/core/moderation/cloud_service/schema.json | 51 ++++++++++++++++++ api/core/moderation/keywords/__init__.py | 0 api/core/moderation/keywords/keywords.py | 13 +++++ api/core/moderation/openai/__init__.py | 0 api/core/moderation/openai/openai.py | 8 +++ api/services/app_model_config_service.py | 52 +++++++++---------- api/services/extension_service.py | 7 +++ 19 files changed, 249 insertions(+), 29 deletions(-) rename {.vscode => api/.vscode}/launch.json (94%) create mode 100644 api/controllers/console/extension.py create mode 100644 api/core/helper/auto_register.py create mode 100644 api/core/helper/extensible.py create mode 100644 api/core/moderation/__init__.py create mode 100644 api/core/moderation/api_based/__init__.py create mode 100644 api/core/moderation/api_based/api_based.py create mode 100644 api/core/moderation/base.py create mode 100644 api/core/moderation/cloud_service/__init__.py create mode 100644 api/core/moderation/cloud_service/cloud_service.py create mode 100644 api/core/moderation/cloud_service/schema.json create mode 100644 api/core/moderation/keywords/__init__.py create mode 100644 api/core/moderation/keywords/keywords.py create mode 100644 api/core/moderation/openai/__init__.py create mode 100644 api/core/moderation/openai/openai.py create mode 100644 api/services/extension_service.py diff --git a/.vscode/launch.json b/api/.vscode/launch.json similarity index 94% rename from .vscode/launch.json rename to api/.vscode/launch.json index 515deb4c0cf85b..e3c1f797c61601 100644 --- a/.vscode/launch.json +++ b/api/.vscode/launch.json @@ -10,7 +10,7 @@ "request": "launch", "module": "flask", "env": { - "FLASK_APP": "api/app.py", + "FLASK_APP": "app.py", "FLASK_DEBUG": "1", "GEVENT_SUPPORT": "True" }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 2476d918870c6f..ac881dc126c0d0 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -6,7 +6,7 @@ api = ExternalApi(bp) # Import other controllers -from . import setup, version, apikey, admin +from . import extension, setup, version, apikey, admin # Import app controllers from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py new file mode 100644 index 00000000000000..0b5ead214bf508 --- /dev/null +++ b/api/controllers/console/extension.py @@ -0,0 +1,23 @@ +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from services.extension_service import ExtensionService + + +class CodeBasedExtension(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('module', type=str, required=True, location='args') + args = parser.parse_args() + + return ExtensionService.get_code_based_extensions(args['module']) + + +api.add_resource(CodeBasedExtension, '/code-based-extensions') \ No newline at end of file diff --git a/api/core/__init__.py b/api/core/__init__.py index e69de29bb2d1d6..8c986fc8bd8afa 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -0,0 +1 @@ +import core.moderation.base \ No newline at end of file diff --git a/api/core/helper/auto_register.py b/api/core/helper/auto_register.py new file mode 100644 index 00000000000000..a908e8073655e3 --- /dev/null +++ b/api/core/helper/auto_register.py @@ -0,0 +1,17 @@ +# Desc: Metaclass for auto-registering subclasses of a class. + +class AutoRegisterMeta(type): + def __init__(cls, name, bases, attrs): + super(AutoRegisterMeta, cls).__init__(name, bases, attrs) + if not hasattr(cls, 'subclasses'): + cls.subclasses = {} + else: + register_name = getattr(cls, 'register_name', name) + cls.subclasses[register_name] = cls + +class AutoRegisterBase(metaclass=AutoRegisterMeta): + @classmethod + def create_instance(cls, subclass_name, *args, **kwargs): + if subclass_name not in cls.subclasses: + raise ValueError(f"No register_name with name '{subclass_name}' found") + return cls.subclasses[subclass_name](*args, **kwargs) diff --git a/api/core/helper/extensible.py b/api/core/helper/extensible.py new file mode 100644 index 00000000000000..7c3220cf0fe0b5 --- /dev/null +++ b/api/core/helper/extensible.py @@ -0,0 +1,34 @@ +import json +import os +import copy + +class Extensible: + __extensions = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.register() + + @classmethod + def register(cls): + subclass_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + subclass_dir_path = os.path.dirname(subclass_path) + parent_folder_name = os.path.basename(os.path.dirname(subclass_dir_path)) + + json_path = os.path.join(subclass_dir_path, 'schema.json') + json_data = {} + if os.path.exists(json_path): + with open(json_path, 'r') as f: + json_data = json.load(f) + + if parent_folder_name not in cls.__extensions: + cls.__extensions[parent_folder_name] = { + "module": parent_folder_name, + "data": [] + } + + cls.__extensions[parent_folder_name]["data"].append(json_data) + + @classmethod + def get_extensions(cls) -> dict: + return copy.deepcopy(cls.__extensions) \ No newline at end of file diff --git a/api/core/moderation/__init__.py b/api/core/moderation/__init__.py new file mode 100644 index 00000000000000..727a8a1e99897c --- /dev/null +++ b/api/core/moderation/__init__.py @@ -0,0 +1,4 @@ +from core.moderation.openai.openai import OpenAIModeration +from core.moderation.keywords.keywords import KeywordsModeration +from core.moderation.api_based.api_based import ApiBasedModeration +from core.moderation.cloud_service.cloud_service import CloudServiceModeration \ No newline at end of file diff --git a/api/core/moderation/api_based/__init__.py b/api/core/moderation/api_based/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api_based/api_based.py new file mode 100644 index 00000000000000..c5cae1c85da7fa --- /dev/null +++ b/api/core/moderation/api_based/api_based.py @@ -0,0 +1,13 @@ +from core.moderation.base import BaseModeration + + +class ApiBasedModeration(BaseModeration): + register_name = "api_based" + + @classmethod + def validate_config(self, config: dict) -> None: + api_based_extension_id = config.get("api_based_extension_id") + if not api_based_extension_id: + raise ValueError("api_based_extension_id is required") + + self._validate_inputs_and_outputs_config(config, False) \ No newline at end of file diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py new file mode 100644 index 00000000000000..ecf9d257fe5cb9 --- /dev/null +++ b/api/core/moderation/base.py @@ -0,0 +1,46 @@ +from abc import abstractclassmethod +from core.helper.auto_register import AutoRegisterBase + + +class BaseModeration(AutoRegisterBase): + + @abstractclassmethod + def validate_config(self, config: dict) -> None: + pass + + @abstractclassmethod + def moderation_for_inputs(self, config: dict): + pass + + @abstractclassmethod + def moderation_for_outputs(self, config: dict): + pass + + @classmethod + def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + # inputs_configs + inputs_configs = config.get("inputs_configs") + if not isinstance(inputs_configs, dict): + raise ValueError("inputs_configs must be a dict") + + # outputs_configs + outputs_configs = config.get("outputs_configs") + if not isinstance(outputs_configs, dict): + raise ValueError("outputs_configs must be a dict") + + inputs_configs_enabled = inputs_configs.get("enabled") + outputs_configs_enabled = outputs_configs.get("enabled") + if not inputs_configs_enabled and not outputs_configs_enabled: + raise ValueError("At least one of inputs_configs or outputs_configs must be enabled") + + # preset_response + if not is_preset_response_required: + return + + if inputs_configs_enabled and not inputs_configs.get("preset_response"): + raise ValueError("inputs_configs.preset_response is required") + + if outputs_configs_enabled and not outputs_configs.get("preset_response"): + raise ValueError("outputs_configs.preset_response is required") + + \ No newline at end of file diff --git a/api/core/moderation/cloud_service/__init__.py b/api/core/moderation/cloud_service/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/moderation/cloud_service/cloud_service.py b/api/core/moderation/cloud_service/cloud_service.py new file mode 100644 index 00000000000000..eaaa9c1eeca8a6 --- /dev/null +++ b/api/core/moderation/cloud_service/cloud_service.py @@ -0,0 +1,5 @@ +from core.moderation.base import BaseModeration +from core.helper.extensible import Extensible + +class CloudServiceModeration(BaseModeration, Extensible): + register_name = "cloud_service" \ No newline at end of file diff --git a/api/core/moderation/cloud_service/schema.json b/api/core/moderation/cloud_service/schema.json new file mode 100644 index 00000000000000..b9491de91dedc0 --- /dev/null +++ b/api/core/moderation/cloud_service/schema.json @@ -0,0 +1,51 @@ +{ + "name": "cloud_service", + "label": { + "en-US": "Cloud Service", + "zh-Hans": "云服务" + }, + "form_schema": [ + { + "select": { + "label": { + "en-US": "Cloud Provider", + "zh-Hans": "云计算厂商" + }, + "variable": "cloud_provider", + "required": true, + "options": [ + "腾讯云", + "阿里云", + "AWS" + ], + "default": "", + "placeholder": "" + } + }, + { + "text-input": { + "label": { + "en-US": "API Endpoint", + "zh-Hans": "API Endpoint" + }, + "variable": "api_endpoint", + "required": true, + "max_length": 100, + "default": "", + "placeholder": "" + } + }, + { + "paragraph": { + "label": { + "en-US": "API Key", + "zh-Hans": "API Key" + }, + "variable": "api_keys", + "required": true, + "default": "", + "placeholder": "" + } + } + ] +} \ No newline at end of file diff --git a/api/core/moderation/keywords/__init__.py b/api/core/moderation/keywords/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py new file mode 100644 index 00000000000000..4d58c1f5f62a86 --- /dev/null +++ b/api/core/moderation/keywords/keywords.py @@ -0,0 +1,13 @@ +from core.moderation.base import BaseModeration + +class KeywordsModeration(BaseModeration): + register_name = "keywords" + + @classmethod + def validate_config(self, config): + keywords = config.get("keywords") + if not keywords: + raise ValueError("keywords is required") + + self._validate_inputs_and_outputs_config(config, True) + \ No newline at end of file diff --git a/api/core/moderation/openai/__init__.py b/api/core/moderation/openai/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py new file mode 100644 index 00000000000000..b65748265da70d --- /dev/null +++ b/api/core/moderation/openai/openai.py @@ -0,0 +1,8 @@ +from core.moderation.base import BaseModeration + +class OpenAIModeration(BaseModeration): + register_name = "openai" + + @classmethod + def validate_config(self, config: dict): + self._validate_inputs_and_outputs_config(config, True) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 79c1ed0ad6f663..a9c6db8d43746a 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -7,6 +7,7 @@ from core.model_providers.models.entity.model_params import ModelType, ModelMode from models.account import Account from services.dataset_service import DatasetService +from core.moderation.base import BaseModeration SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -153,33 +154,6 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: if not isinstance(config["more_like_this"]["enabled"], bool): raise ValueError("enabled in more_like_this must be of boolean type") - # sensitive_word_avoidance - if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: - config["sensitive_word_avoidance"] = { - "enabled": False - } - - if not isinstance(config["sensitive_word_avoidance"], dict): - raise ValueError("sensitive_word_avoidance must be of dict type") - - if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: - config["sensitive_word_avoidance"]["enabled"] = False - - if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool): - raise ValueError("enabled in sensitive_word_avoidance must be of boolean type") - - if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]: - config["sensitive_word_avoidance"]["words"] = "" - - if not isinstance(config["sensitive_word_avoidance"]["words"], str): - raise ValueError("words in sensitive_word_avoidance must be of string type") - - if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]: - config["sensitive_word_avoidance"]["canned_response"] = "" - - if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str): - raise ValueError("canned_response in sensitive_word_avoidance must be of string type") - # model if 'model' not in config: raise ValueError("model is required") @@ -339,6 +313,9 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: # advanced prompt validation AppModelConfigService.is_advanced_prompt_valid(config, mode) + # moderation validation + AppModelConfigService.is_moderation_valid(config) + # Filter out extra parameters filtered_config = { "opening_statement": config["opening_statement"], @@ -365,6 +342,27 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: } return filtered_config + + @staticmethod + def is_moderation_valid(config): + if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: + config["sensitive_word_avoidance"] = { + "enabled": False + } + + if not isinstance(config["sensitive_word_avoidance"], dict): + raise ValueError("sensitive_word_avoidance must be of dict type") + + if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: + config["sensitive_word_avoidance"]["enabled"] = False + + if not config["sensitive_word_avoidance"]["enabled"]: + return + + if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: + raise ValueError("sensitive_word_avoidance.type is required") + + BaseModeration.create_instance(type).validate_config(config["sensitive_word_avoidance"]["configs"]) @staticmethod def is_dataset_query_variable_valid(config: dict, mode: str) -> None: diff --git a/api/services/extension_service.py b/api/services/extension_service.py new file mode 100644 index 00000000000000..08ecf50efc84a2 --- /dev/null +++ b/api/services/extension_service.py @@ -0,0 +1,7 @@ +from core.helper.extensible import Extensible + +class ExtensionService: + + @classmethod + def get_code_based_extensions(cls, module: str) -> list[dict]: + return Extensible.get_extensions().get(module, []) \ No newline at end of file From 382d65c7ae877b56db88cac1869c0de3bf6a11b2 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 27 Oct 2023 10:23:29 +0800 Subject: [PATCH 02/57] update. --- api/services/app_model_config_service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index a9c6db8d43746a..cf3da7e31b8ebb 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -361,8 +361,11 @@ def is_moderation_valid(config): if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: raise ValueError("sensitive_word_avoidance.type is required") + + type = config["sensitive_word_avoidance"]["type"] + config = config["sensitive_word_avoidance"]["configs"] - BaseModeration.create_instance(type).validate_config(config["sensitive_word_avoidance"]["configs"]) + BaseModeration.create_instance(type).validate_config(config) @staticmethod def is_dataset_query_variable_valid(config: dict, mode: str) -> None: From 3c36ce7b2a80d4de79ffe902c98107f3cdb3132c Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 27 Oct 2023 16:11:01 +0800 Subject: [PATCH 03/57] api_based_extension. --- api/controllers/console/extension.py | 97 ++++++++++++++++++- api/fields/api_based_extension_fields.py | 11 +++ .../968fff4c0ab9_add_api_based_extension.py | 45 +++++++++ api/models/api_based_extension.py | 17 ++++ api/services/api_based_extension_service.py | 66 +++++++++++++ api/services/code_based_extension_service.py | 7 ++ api/services/extension_service.py | 7 -- 7 files changed, 238 insertions(+), 12 deletions(-) create mode 100644 api/fields/api_based_extension_fields.py create mode 100644 api/migrations/versions/968fff4c0ab9_add_api_based_extension.py create mode 100644 api/models/api_based_extension.py create mode 100644 api/services/api_based_extension_service.py create mode 100644 api/services/code_based_extension_service.py delete mode 100644 api/services/extension_service.py diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 0b5ead214bf508..4a8439068a5dfa 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,13 +1,17 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse, marshal_with +from flask_login import current_user from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from libs.login import login_required -from services.extension_service import ExtensionService +from models.api_based_extension import APIBasedExtension +from fields.api_based_extension_fields import api_based_extension_fields +from services.code_based_extension_service import CodeBasedExtensionService +from services.api_based_extension_service import APIBasedExtensionService -class CodeBasedExtension(Resource): +class CodeBasedExtensionAPI(Resource): @setup_required @login_required @@ -17,7 +21,90 @@ def get(self): parser.add_argument('module', type=str, required=True, location='args') args = parser.parse_args() - return ExtensionService.get_code_based_extensions(args['module']) + return CodeBasedExtensionService.get_code_based_extension(args['module']) + +class APIBasedExtensionAPI(Resource): -api.add_resource(CodeBasedExtension, '/code-based-extensions') \ No newline at end of file + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def get(self): + tenant_id = current_user.current_tenant_id + return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('api_endpoint', type=str, required=True, location='json') + parser.add_argument('api_key', type=str, required=True, location='json') + args = parser.parse_args() + + extension_data = APIBasedExtension( + tenant_id=current_user.current_tenant_id, + name=args['name'], + api_endpoint=args['api_endpoint'], + api_key=args['api_key'] + ) + + return APIBasedExtensionService.save(extension_data) + +class APIBasedExtensionDetailAPI(Resource): + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def get(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def post(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('api_endpoint', type=str, required=True, location='json') + parser.add_argument('api_key', type=str, required=True, location='json') + args = parser.parse_args() + + extension_data_from_db.name = args['name'] + extension_data_from_db.api_endpoint = args['api_endpoint'] + + if args['api_key'] != '[__HIDDEN__]': + extension_data_from_db.api_key = args['api_key'] + + return APIBasedExtensionService.save(extension_data_from_db) + + @setup_required + @login_required + @account_initialization_required + def delete(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + APIBasedExtensionService.delete(extension_data_from_db) + + return {'result': 'success'} + + +api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') + +api.add_resource(APIBasedExtensionAPI, '/api-based-extension') +api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') \ No newline at end of file diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py new file mode 100644 index 00000000000000..b3dfaf536b7d36 --- /dev/null +++ b/api/fields/api_based_extension_fields.py @@ -0,0 +1,11 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +api_based_extension_fields = { + 'id': fields.String, + 'name': fields.String, + 'api_endpoint': fields.String, + 'api_key': fields.String, + 'created_at': TimestampField +} \ No newline at end of file diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py new file mode 100644 index 00000000000000..8876ef71489433 --- /dev/null +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -0,0 +1,45 @@ +"""add_api_based_extension + +Revision ID: 968fff4c0ab9 +Revises: b3a09c049e8e +Create Date: 2023-10-27 13:05:58.901858 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '968fff4c0ab9' +down_revision = 'b3a09c049e8e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.create_table('api_based_extensions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('api_endpoint', sa.String(length=255), nullable=False), + sa.Column('api_key', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey') + ) + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.create_index('api_based_extension_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.drop_index('api_based_extension_tenant_idx') + + op.drop_table('api_based_extensions') + + # ### end Alembic commands ### diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py new file mode 100644 index 00000000000000..f31a404a5c9b6b --- /dev/null +++ b/api/models/api_based_extension.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db + +class APIBasedExtension(db.Model): + __tablename__ = 'api_based_extensions' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), + db.Index('api_based_extension_tenant_idx', 'tenant_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + api_endpoint = db.Column(db.String(255), nullable=False) + api_key = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) \ No newline at end of file diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py new file mode 100644 index 00000000000000..c6634dd27b6657 --- /dev/null +++ b/api/services/api_based_extension_service.py @@ -0,0 +1,66 @@ +from extensions.ext_database import db +from models.api_based_extension import APIBasedExtension + +class APIBasedExtensionService: + + @staticmethod + def get_all_by_tenant_id(tenant_id: str): + return db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=tenant_id) \ + .order_by(APIBasedExtension.created_at.desc()) \ + .all() + + @staticmethod + def save(extension_data: APIBasedExtension) -> APIBasedExtension: + # name + if not extension_data.name: + raise ValueError("name must not be empty") + + if not extension_data.id: + # case one: check new data, name must be unique + is_name_existed = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=extension_data.tenant_id) \ + .filter_by(name=extension_data.name) \ + .first() + + if is_name_existed: + raise ValueError("name must be unique, it is already existed") + else: + # case two: check existing data, name must be unique + is_name_existed = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=extension_data.tenant_id) \ + .filter_by(name=extension_data.name) \ + .filter(APIBasedExtension.id != extension_data.id) \ + .first() + + if is_name_existed: + raise ValueError("name must be unique, it is already existed") + + # api_endpoint + if not extension_data.api_endpoint: + raise ValueError("api_endpoint must not be empty") + + # api_key + if not extension_data.api_key: + raise ValueError("api_key must not be empty") + + db.session.add(extension_data) + db.session.commit() + return extension_data + + @staticmethod + def delete(extension_data: APIBasedExtension) -> None: + db.session.delete(extension_data) + db.session.commit() + + @staticmethod + def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + api_based_extension = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=tenant_id) \ + .filter_by(id=api_based_extension_id) \ + .first() + + if not api_based_extension: + raise ValueError("API based extension is not found") + + return api_based_extension \ No newline at end of file diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py new file mode 100644 index 00000000000000..0541b8e360470b --- /dev/null +++ b/api/services/code_based_extension_service.py @@ -0,0 +1,7 @@ +from core.helper.extensible import Extensible + +class CodeBasedExtensionService: + + @staticmethod + def get_code_based_extension(module: str) -> list[dict]: + return Extensible.get_extensions().get(module, []) \ No newline at end of file diff --git a/api/services/extension_service.py b/api/services/extension_service.py deleted file mode 100644 index 08ecf50efc84a2..00000000000000 --- a/api/services/extension_service.py +++ /dev/null @@ -1,7 +0,0 @@ -from core.helper.extensible import Extensible - -class ExtensionService: - - @classmethod - def get_code_based_extensions(cls, module: str) -> list[dict]: - return Extensible.get_extensions().get(module, []) \ No newline at end of file From b052a84838289902cb0dcebf5ab96afde919fed5 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 30 Oct 2023 12:42:31 +0800 Subject: [PATCH 04/57] refactor. --- api/controllers/console/extension.py | 7 ++++--- api/fields/api_based_extension_fields.py | 7 ++++++- .../versions/968fff4c0ab9_add_api_based_extension.py | 2 +- api/models/api_based_extension.py | 2 +- api/services/api_based_extension_service.py | 6 +++++- 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4a8439068a5dfa..988ff6b54de752 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -52,7 +52,7 @@ def post(self): api_key=args['api_key'] ) - return APIBasedExtensionService.save(extension_data) + return APIBasedExtensionService.save(extension_data, True) class APIBasedExtensionDetailAPI(Resource): @@ -85,10 +85,11 @@ def post(self, id): extension_data_from_db.name = args['name'] extension_data_from_db.api_endpoint = args['api_endpoint'] + need_encrypt = False if args['api_key'] != '[__HIDDEN__]': - extension_data_from_db.api_key = args['api_key'] + need_encrypt = True - return APIBasedExtensionService.save(extension_data_from_db) + return APIBasedExtensionService.save(extension_data_from_db, need_encrypt) @setup_required @login_required diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index b3dfaf536b7d36..e9319029d3175f 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -2,10 +2,15 @@ from libs.helper import TimestampField + +class HiddenAPIKey(fields.Raw): + def output(self, key, obj): + return obj.api_key[:8] + '***' + obj.api_key[-8:] + api_based_extension_fields = { 'id': fields.String, 'name': fields.String, 'api_endpoint': fields.String, - 'api_key': fields.String, + 'api_key': HiddenAPIKey, 'created_at': TimestampField } \ No newline at end of file diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py index 8876ef71489433..57b28e707f3b2b 100644 --- a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -24,7 +24,7 @@ def upgrade(): sa.Column('tenant_id', postgresql.UUID(), nullable=False), sa.Column('name', sa.String(length=255), nullable=False), sa.Column('api_endpoint', sa.String(length=255), nullable=False), - sa.Column('api_key', sa.String(length=255), nullable=False), + sa.Column('api_key', sa.Text(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey') ) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index f31a404a5c9b6b..f0bcbc7b5ae595 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -13,5 +13,5 @@ class APIBasedExtension(db.Model): tenant_id = db.Column(UUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) - api_key = db.Column(db.String(255), nullable=False) + api_key = db.Column(db.Text, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) \ No newline at end of file diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index c6634dd27b6657..4c42212f64ed94 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,5 +1,6 @@ from extensions.ext_database import db from models.api_based_extension import APIBasedExtension +from core.helper.encrypter import encrypt_token class APIBasedExtensionService: @@ -11,7 +12,7 @@ def get_all_by_tenant_id(tenant_id: str): .all() @staticmethod - def save(extension_data: APIBasedExtension) -> APIBasedExtension: + def save(extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExtension: # name if not extension_data.name: raise ValueError("name must not be empty") @@ -44,6 +45,9 @@ def save(extension_data: APIBasedExtension) -> APIBasedExtension: if not extension_data.api_key: raise ValueError("api_key must not be empty") + if need_encrypt: + extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) + db.session.add(extension_data) db.session.commit() return extension_data From dc849bfc487e43eb67425f04c0e011c0c257c032 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 30 Oct 2023 17:29:41 +0800 Subject: [PATCH 05/57] update. --- api/core/moderation/cloud_service/schema.json | 69 +++++++++---------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/api/core/moderation/cloud_service/schema.json b/api/core/moderation/cloud_service/schema.json index b9491de91dedc0..b74d8564f77533 100644 --- a/api/core/moderation/cloud_service/schema.json +++ b/api/core/moderation/cloud_service/schema.json @@ -6,46 +6,43 @@ }, "form_schema": [ { - "select": { - "label": { - "en-US": "Cloud Provider", - "zh-Hans": "云计算厂商" - }, - "variable": "cloud_provider", - "required": true, - "options": [ - "腾讯云", - "阿里云", - "AWS" - ], - "default": "", - "placeholder": "" - } + "type": "select", + "label": { + "en-US": "Cloud Provider", + "zh-Hans": "云计算厂商" + }, + "variable": "cloud_provider", + "required": true, + "options": [ + "腾讯云", + "阿里云", + "AWS" + ], + "default": "", + "placeholder": "" }, { - "text-input": { - "label": { - "en-US": "API Endpoint", - "zh-Hans": "API Endpoint" - }, - "variable": "api_endpoint", - "required": true, - "max_length": 100, - "default": "", - "placeholder": "" - } + "type": "text-input", + "label": { + "en-US": "API Endpoint", + "zh-Hans": "API Endpoint" + }, + "variable": "api_endpoint", + "required": true, + "max_length": 100, + "default": "", + "placeholder": "" }, { - "paragraph": { - "label": { - "en-US": "API Key", - "zh-Hans": "API Key" - }, - "variable": "api_keys", - "required": true, - "default": "", - "placeholder": "" - } + "type": "paragraph", + "label": { + "en-US": "API Key", + "zh-Hans": "API Key" + }, + "variable": "api_keys", + "required": true, + "default": "", + "placeholder": "" } ] } \ No newline at end of file From ce8af8b8b3e8fc3a380ccec9771e0e44572ca132 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 30 Oct 2023 18:53:38 +0800 Subject: [PATCH 06/57] update. --- api/core/helper/auto_register.py | 17 ----------------- api/core/moderation/api_based/api_based.py | 2 +- api/core/moderation/base.py | 17 ++++++++++++++--- .../moderation/cloud_service/cloud_service.py | 2 +- api/core/moderation/keywords/keywords.py | 2 +- api/core/moderation/openai/openai.py | 2 +- 6 files changed, 18 insertions(+), 24 deletions(-) delete mode 100644 api/core/helper/auto_register.py diff --git a/api/core/helper/auto_register.py b/api/core/helper/auto_register.py deleted file mode 100644 index a908e8073655e3..00000000000000 --- a/api/core/helper/auto_register.py +++ /dev/null @@ -1,17 +0,0 @@ -# Desc: Metaclass for auto-registering subclasses of a class. - -class AutoRegisterMeta(type): - def __init__(cls, name, bases, attrs): - super(AutoRegisterMeta, cls).__init__(name, bases, attrs) - if not hasattr(cls, 'subclasses'): - cls.subclasses = {} - else: - register_name = getattr(cls, 'register_name', name) - cls.subclasses[register_name] = cls - -class AutoRegisterBase(metaclass=AutoRegisterMeta): - @classmethod - def create_instance(cls, subclass_name, *args, **kwargs): - if subclass_name not in cls.subclasses: - raise ValueError(f"No register_name with name '{subclass_name}' found") - return cls.subclasses[subclass_name](*args, **kwargs) diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api_based/api_based.py index c5cae1c85da7fa..0e2146339c30ae 100644 --- a/api/core/moderation/api_based/api_based.py +++ b/api/core/moderation/api_based/api_based.py @@ -2,7 +2,7 @@ class ApiBasedModeration(BaseModeration): - register_name = "api_based" + type = "api_based" @classmethod def validate_config(self, config: dict) -> None: diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index ecf9d257fe5cb9..9ad9ae866c373b 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,8 +1,14 @@ from abc import abstractclassmethod -from core.helper.auto_register import AutoRegisterBase -class BaseModeration(AutoRegisterBase): +class BaseModeration(): + _subclasses = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + type = getattr(cls, 'type', None) + if type: + BaseModeration._subclasses[type] = cls @abstractclassmethod def validate_config(self, config: dict) -> None: @@ -43,4 +49,9 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r if outputs_configs_enabled and not outputs_configs.get("preset_response"): raise ValueError("outputs_configs.preset_response is required") - \ No newline at end of file + @staticmethod + def create_instance(type: str, *args, **kwargs): + if type in BaseModeration._subclasses: + return BaseModeration._subclasses[type](*args, **kwargs) + else: + raise ValueError(f"No type named {type} found.") \ No newline at end of file diff --git a/api/core/moderation/cloud_service/cloud_service.py b/api/core/moderation/cloud_service/cloud_service.py index eaaa9c1eeca8a6..763e7979783d35 100644 --- a/api/core/moderation/cloud_service/cloud_service.py +++ b/api/core/moderation/cloud_service/cloud_service.py @@ -2,4 +2,4 @@ from core.helper.extensible import Extensible class CloudServiceModeration(BaseModeration, Extensible): - register_name = "cloud_service" \ No newline at end of file + type = "cloud_service" \ No newline at end of file diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 4d58c1f5f62a86..d89381e165c822 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,7 +1,7 @@ from core.moderation.base import BaseModeration class KeywordsModeration(BaseModeration): - register_name = "keywords" + type = "keywords" @classmethod def validate_config(self, config): diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index b65748265da70d..193f2c7f92cb42 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -1,7 +1,7 @@ from core.moderation.base import BaseModeration class OpenAIModeration(BaseModeration): - register_name = "openai" + type = "openai" @classmethod def validate_config(self, config: dict): From bfa2a07c50ef9fd63c693e156233b9330a6446d3 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Tue, 31 Oct 2023 11:44:45 +0800 Subject: [PATCH 07/57] update. --- api/core/moderation/api_based/api_based.py | 4 ++-- api/core/moderation/keywords/keywords.py | 4 ++-- api/core/moderation/openai/openai.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api_based/api_based.py index 0e2146339c30ae..d83d506d683dce 100644 --- a/api/core/moderation/api_based/api_based.py +++ b/api/core/moderation/api_based/api_based.py @@ -5,9 +5,9 @@ class ApiBasedModeration(BaseModeration): type = "api_based" @classmethod - def validate_config(self, config: dict) -> None: + def validate_config(cls, config: dict) -> None: api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - self._validate_inputs_and_outputs_config(config, False) \ No newline at end of file + cls._validate_inputs_and_outputs_config(config, False) \ No newline at end of file diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index d89381e165c822..af51871a7701ab 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -4,10 +4,10 @@ class KeywordsModeration(BaseModeration): type = "keywords" @classmethod - def validate_config(self, config): + def validate_config(cls, config): keywords = config.get("keywords") if not keywords: raise ValueError("keywords is required") - self._validate_inputs_and_outputs_config(config, True) + cls._validate_inputs_and_outputs_config(config, True) \ No newline at end of file diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index 193f2c7f92cb42..5d06ecae827acb 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -4,5 +4,5 @@ class OpenAIModeration(BaseModeration): type = "openai" @classmethod - def validate_config(self, config: dict): - self._validate_inputs_and_outputs_config(config, True) \ No newline at end of file + def validate_config(cls, config: dict): + cls._validate_inputs_and_outputs_config(config, True) \ No newline at end of file From 2819b08bd7aecddc1a5d0b26ffdfc9111fe98bcb Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 31 Oct 2023 21:44:09 +0800 Subject: [PATCH 08/57] feat: refactor code-based extension --- api/app.py | 3 +- api/core/extension/__init__.py | 0 api/core/extension/extensible.py | 84 +++++++++++++++++++ api/core/extension/extension_factory.py | 37 ++++++++ api/core/external_data_tool/__init__.py | 0 .../external_data_tool/api_based/__builtin__ | 0 .../external_data_tool/api_based/__init__.py | 0 .../external_data_tool/api_based/api_based.py | 13 +++ api/core/external_data_tool/base.py | 5 ++ api/core/helper/extensible.py | 34 -------- api/core/moderation/api_based/api_based.py | 4 +- api/core/moderation/base.py | 13 +-- .../moderation/cloud_service/cloud_service.py | 6 +- api/core/moderation/cloud_service/schema.json | 1 - api/core/moderation/keywords/keywords.py | 4 +- api/core/moderation/openai/openai.py | 4 +- api/extensions/ext_code_based_extension.py | 8 ++ api/services/app_model_config_service.py | 4 +- api/services/code_based_extension_service.py | 2 +- 19 files changed, 165 insertions(+), 57 deletions(-) create mode 100644 api/core/extension/__init__.py create mode 100644 api/core/extension/extensible.py create mode 100644 api/core/extension/extension_factory.py create mode 100644 api/core/external_data_tool/__init__.py create mode 100644 api/core/external_data_tool/api_based/__builtin__ create mode 100644 api/core/external_data_tool/api_based/__init__.py create mode 100644 api/core/external_data_tool/api_based/api_based.py create mode 100644 api/core/external_data_tool/base.py delete mode 100644 api/core/helper/extensible.py create mode 100644 api/extensions/ext_code_based_extension.py diff --git a/api/app.py b/api/app.py index ef32ad44e1f70c..bc0d25224e8b04 100644 --- a/api/app.py +++ b/api/app.py @@ -19,7 +19,7 @@ from core.model_providers.providers import hosted from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ - ext_database, ext_storage, ext_mail, ext_stripe + ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension from extensions.ext_database import db from extensions.ext_login import login_manager @@ -79,6 +79,7 @@ def create_app(test_config=None) -> Flask: def initialize_extensions(app): # Since the application instance is now created, pass it to each Flask # extension instance to bind it to the Flask application instance (app) + ext_code_based_extension.init() ext_database.init_app(app) ext_migrate.init(app, db) ext_redis.init_app(app) diff --git a/api/core/extension/__init__.py b/api/core/extension/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py new file mode 100644 index 00000000000000..2c67fc107cce34 --- /dev/null +++ b/api/core/extension/extensible.py @@ -0,0 +1,84 @@ +import importlib.util +import json +import logging +import os +from typing import Any + +from pydantic import BaseModel + + +class ModuleExtension(BaseModel): + extension_class: Any + name: str + label: dict = {} + form_schema: list = [] + builtin: bool = True + + +class Extensible: + @classmethod + def scan_extensions(cls): + extensions = {} + + # get the path of the current class + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_dir_path = os.path.dirname(current_path) + + # traverse subdirectories + for subdir_name in os.listdir(current_dir_path): + if subdir_name.startswith('__'): + continue + + subdir_path = os.path.join(current_dir_path, subdir_name) + extension_name = subdir_name + if os.path.isdir(subdir_path): + file_names = os.listdir(subdir_path) + + # is builtin extension, builtin extension + # in the front-end page and business logic, there are special treatments. + builtin = False + if '__builtin__' in file_names: + builtin = True + + if (extension_name + '.py') not in file_names: + logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + continue + + # Dynamic loading {subdir_name}.py file and find the subclass of Extensible + py_path = os.path.join(subdir_path, extension_name + '.py') + spec = importlib.util.spec_from_file_location(extension_name, py_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + extension_class = None + for name, obj in vars(mod).items(): + if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: + extension_class = obj + break + + if not extension_class: + logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") + continue + + json_data = {} + if not builtin: + if 'schema.json' not in file_names: + logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") + continue + + json_path = os.path.join(subdir_path, 'schema.json') + json_data = {} + if os.path.exists(json_path): + builtin = False + with open(json_path, 'r') as f: + json_data = json.load(f) + + extensions[extension_name] = ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get('label', {}), + form_schema=json_data.get('form_schema', []), + builtin=builtin + ) + + return extensions diff --git a/api/core/extension/extension_factory.py b/api/core/extension/extension_factory.py new file mode 100644 index 00000000000000..767eff2efc909c --- /dev/null +++ b/api/core/extension/extension_factory.py @@ -0,0 +1,37 @@ +from core.extension.extensible import ModuleExtension +from core.external_data_tool.base import ExternalDataTool +from core.moderation.base import Moderation + + +class ExtensionFactory: + __module_extensions: dict[str, dict[str, ModuleExtension]] = {} + + module_classes = { + 'moderation': Moderation, + 'external_data_tool': ExternalDataTool + } + + def init(self): + for module, module_class in self.module_classes.items(): + self.__module_extensions[module] = module_class.scan_extensions() + + def module_extensions(self, module: str) -> list[ModuleExtension]: + module_extensions = self.__module_extensions.get(module) + + if not module_extensions: + raise ValueError(f"Extension Module {module} not found") + + return list(module_extensions.values()) + + def module_extension(self, module: str, extension_name: str) -> ModuleExtension: + module_extensions = self.__module_extensions.get(module) + + if not module_extensions: + raise ValueError(f"Extension Module {module} not found") + + module_extension = module_extensions.get(extension_name) + + if not module_extension: + raise ValueError(f"Extension {extension_name} not found") + + return module_extension diff --git a/api/core/external_data_tool/__init__.py b/api/core/external_data_tool/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/external_data_tool/api_based/__builtin__ b/api/core/external_data_tool/api_based/__builtin__ new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/external_data_tool/api_based/__init__.py b/api/core/external_data_tool/api_based/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/external_data_tool/api_based/api_based.py b/api/core/external_data_tool/api_based/api_based.py new file mode 100644 index 00000000000000..0e98b4417c11a7 --- /dev/null +++ b/api/core/external_data_tool/api_based/api_based.py @@ -0,0 +1,13 @@ +from core.external_data_tool.base import ExternalDataTool + + +class ApiBasedExternalDataTool(ExternalDataTool): + type = "api_based" + + @classmethod + def validate_config(self, config: dict) -> None: + api_based_extension_id = config.get("api_based_extension_id") + if not api_based_extension_id: + raise ValueError("api_based_extension_id is required") + + # todo diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py new file mode 100644 index 00000000000000..beb2ae85cd7fd2 --- /dev/null +++ b/api/core/external_data_tool/base.py @@ -0,0 +1,5 @@ +from core.extension.extensible import Extensible + + +class ExternalDataTool(Extensible): + pass diff --git a/api/core/helper/extensible.py b/api/core/helper/extensible.py deleted file mode 100644 index 7c3220cf0fe0b5..00000000000000 --- a/api/core/helper/extensible.py +++ /dev/null @@ -1,34 +0,0 @@ -import json -import os -import copy - -class Extensible: - __extensions = {} - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls.register() - - @classmethod - def register(cls): - subclass_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') - subclass_dir_path = os.path.dirname(subclass_path) - parent_folder_name = os.path.basename(os.path.dirname(subclass_dir_path)) - - json_path = os.path.join(subclass_dir_path, 'schema.json') - json_data = {} - if os.path.exists(json_path): - with open(json_path, 'r') as f: - json_data = json.load(f) - - if parent_folder_name not in cls.__extensions: - cls.__extensions[parent_folder_name] = { - "module": parent_folder_name, - "data": [] - } - - cls.__extensions[parent_folder_name]["data"].append(json_data) - - @classmethod - def get_extensions(cls) -> dict: - return copy.deepcopy(cls.__extensions) \ No newline at end of file diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api_based/api_based.py index d83d506d683dce..f11d0f5dfdd7a7 100644 --- a/api/core/moderation/api_based/api_based.py +++ b/api/core/moderation/api_based/api_based.py @@ -1,7 +1,7 @@ -from core.moderation.base import BaseModeration +from core.moderation.base import Moderation -class ApiBasedModeration(BaseModeration): +class ApiBasedModeration(Moderation): type = "api_based" @classmethod diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 9ad9ae866c373b..0b4dd7455c4883 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,14 +1,9 @@ from abc import abstractclassmethod +from core.extension.extensible import Extensible -class BaseModeration(): - _subclasses = {} - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - type = getattr(cls, 'type', None) - if type: - BaseModeration._subclasses[type] = cls +class Moderation(Extensible): @abstractclassmethod def validate_config(self, config: dict) -> None: @@ -51,7 +46,7 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r @staticmethod def create_instance(type: str, *args, **kwargs): - if type in BaseModeration._subclasses: - return BaseModeration._subclasses[type](*args, **kwargs) + if type in Moderation._subclasses: + return Moderation._subclasses[type](*args, **kwargs) else: raise ValueError(f"No type named {type} found.") \ No newline at end of file diff --git a/api/core/moderation/cloud_service/cloud_service.py b/api/core/moderation/cloud_service/cloud_service.py index 763e7979783d35..c404ab27a92394 100644 --- a/api/core/moderation/cloud_service/cloud_service.py +++ b/api/core/moderation/cloud_service/cloud_service.py @@ -1,5 +1,5 @@ -from core.moderation.base import BaseModeration -from core.helper.extensible import Extensible +from core.moderation.base import Moderation +from core.extension.extensible import Extensible -class CloudServiceModeration(BaseModeration, Extensible): +class CloudServiceModeration(Moderation): type = "cloud_service" \ No newline at end of file diff --git a/api/core/moderation/cloud_service/schema.json b/api/core/moderation/cloud_service/schema.json index b74d8564f77533..3e05822ff60cd2 100644 --- a/api/core/moderation/cloud_service/schema.json +++ b/api/core/moderation/cloud_service/schema.json @@ -1,5 +1,4 @@ { - "name": "cloud_service", "label": { "en-US": "Cloud Service", "zh-Hans": "云服务" diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index af51871a7701ab..7ce4f9542c9db5 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ -from core.moderation.base import BaseModeration +from core.moderation.base import Moderation -class KeywordsModeration(BaseModeration): +class KeywordsModeration(Moderation): type = "keywords" @classmethod diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index 5d06ecae827acb..d21b541bc777f3 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -1,6 +1,6 @@ -from core.moderation.base import BaseModeration +from core.moderation.base import Moderation -class OpenAIModeration(BaseModeration): +class OpenAIModeration(Moderation): type = "openai" @classmethod diff --git a/api/extensions/ext_code_based_extension.py b/api/extensions/ext_code_based_extension.py new file mode 100644 index 00000000000000..118cc2fb5789f3 --- /dev/null +++ b/api/extensions/ext_code_based_extension.py @@ -0,0 +1,8 @@ +from core.extension.extension_factory import ExtensionFactory + + +def init(): + code_based_extension_factory.init() + + +code_based_extension_factory = ExtensionFactory() diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index cf3da7e31b8ebb..7bba5cf14a839d 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -7,7 +7,7 @@ from core.model_providers.models.entity.model_params import ModelType, ModelMode from models.account import Account from services.dataset_service import DatasetService -from core.moderation.base import BaseModeration +from core.moderation.base import Moderation SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -365,7 +365,7 @@ def is_moderation_valid(config): type = config["sensitive_word_avoidance"]["type"] config = config["sensitive_word_avoidance"]["configs"] - BaseModeration.create_instance(type).validate_config(config) + Moderation.create_instance(type).validate_config(config) @staticmethod def is_dataset_query_variable_valid(config: dict, mode: str) -> None: diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 0541b8e360470b..b9737eb6272951 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -1,4 +1,4 @@ -from core.helper.extensible import Extensible +from core.extension.extensible import Extensible class CodeBasedExtensionService: From a21a950ce49f38c76804c85392e51755fa9a3ec6 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 31 Oct 2023 22:45:26 +0800 Subject: [PATCH 09/57] feat: refactor. --- api/core/extension/extensible.py | 6 ++++++ .../{extension_factory.py => extension.py} | 18 +++++++++++------- api/core/moderation/base.py | 7 ------- api/core/moderation/moderation_factory.py | 9 +++++++++ api/extensions/ext_code_based_extension.py | 6 +++--- api/services/app_model_config_service.py | 6 +++--- 6 files changed, 32 insertions(+), 20 deletions(-) rename api/core/extension/{extension_factory.py => extension.py} (57%) create mode 100644 api/core/moderation/moderation_factory.py diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 2c67fc107cce34..d6a3822f3e147c 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,3 +1,4 @@ +import enum import importlib.util import json import logging @@ -7,6 +8,11 @@ from pydantic import BaseModel +class ExtensionModule(enum.Enum): + MODERATION = 'moderation' + EXTERNAL_DATA_TOOL = 'external_data_tool' + + class ModuleExtension(BaseModel): extension_class: Any name: str diff --git a/api/core/extension/extension_factory.py b/api/core/extension/extension.py similarity index 57% rename from api/core/extension/extension_factory.py rename to api/core/extension/extension.py index 767eff2efc909c..494f46d017cc3b 100644 --- a/api/core/extension/extension_factory.py +++ b/api/core/extension/extension.py @@ -1,19 +1,19 @@ -from core.extension.extensible import ModuleExtension +from core.extension.extensible import ModuleExtension, ExtensionModule from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation -class ExtensionFactory: +class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} module_classes = { - 'moderation': Moderation, - 'external_data_tool': ExternalDataTool + ExtensionModule.MODERATION: Moderation, + ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool } def init(self): for module, module_class in self.module_classes.items(): - self.__module_extensions[module] = module_class.scan_extensions() + self.__module_extensions[module.value] = module_class.scan_extensions() def module_extensions(self, module: str) -> list[ModuleExtension]: module_extensions = self.__module_extensions.get(module) @@ -23,8 +23,8 @@ def module_extensions(self, module: str) -> list[ModuleExtension]: return list(module_extensions.values()) - def module_extension(self, module: str, extension_name: str) -> ModuleExtension: - module_extensions = self.__module_extensions.get(module) + def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension: + module_extensions = self.__module_extensions.get(module.value) if not module_extensions: raise ValueError(f"Extension Module {module} not found") @@ -35,3 +35,7 @@ def module_extension(self, module: str, extension_name: str) -> ModuleExtension: raise ValueError(f"Extension {extension_name} not found") return module_extension + + def extension_class(self, module: ExtensionModule, extension_name: str) -> type: + module_extension = self.module_extension(module, extension_name) + return module_extension.extension_class diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 0b4dd7455c4883..60f2466840c197 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -43,10 +43,3 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r if outputs_configs_enabled and not outputs_configs.get("preset_response"): raise ValueError("outputs_configs.preset_response is required") - - @staticmethod - def create_instance(type: str, *args, **kwargs): - if type in Moderation._subclasses: - return Moderation._subclasses[type](*args, **kwargs) - else: - raise ValueError(f"No type named {type} found.") \ No newline at end of file diff --git a/api/core/moderation/moderation_factory.py b/api/core/moderation/moderation_factory.py new file mode 100644 index 00000000000000..318681395cbb4b --- /dev/null +++ b/api/core/moderation/moderation_factory.py @@ -0,0 +1,9 @@ +from core.extension.extensible import ExtensionModule +from core.moderation.base import Moderation +from extensions.ext_code_based_extension import code_based_extension + + +class ModerationFactory: + @staticmethod + def get(name: str) -> type[Moderation]: + return code_based_extension.extension_class(ExtensionModule.MODERATION, name) diff --git a/api/extensions/ext_code_based_extension.py b/api/extensions/ext_code_based_extension.py index 118cc2fb5789f3..a8ae733aa69927 100644 --- a/api/extensions/ext_code_based_extension.py +++ b/api/extensions/ext_code_based_extension.py @@ -1,8 +1,8 @@ -from core.extension.extension_factory import ExtensionFactory +from core.extension.extension import Extension def init(): - code_based_extension_factory.init() + code_based_extension.init() -code_based_extension_factory = ExtensionFactory() +code_based_extension = Extension() diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 7bba5cf14a839d..d5b907a699a10e 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,13 +1,13 @@ import re import uuid +from core.moderation.moderation_factory import ModerationFactory from core.prompt.prompt_transform import AppMode from core.agent.agent_executor import PlanningStrategy from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.models.entity.model_params import ModelType, ModelMode from models.account import Account from services.dataset_service import DatasetService -from core.moderation.base import Moderation SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -365,8 +365,8 @@ def is_moderation_valid(config): type = config["sensitive_word_avoidance"]["type"] config = config["sensitive_word_avoidance"]["configs"] - Moderation.create_instance(type).validate_config(config) - + ModerationFactory.get(type).validate_config(config) + @staticmethod def is_dataset_query_variable_valid(config: dict, mode: str) -> None: # Only check when mode is completion From fb2bc44d3303cb5f25b32c48ae1e15ae799b978b Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 31 Oct 2023 22:52:14 +0800 Subject: [PATCH 10/57] feat: refactor moderation factory --- api/core/moderation/moderation_factory.py | 10 ++++++---- api/services/app_model_config_service.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/api/core/moderation/moderation_factory.py b/api/core/moderation/moderation_factory.py index 318681395cbb4b..6b22ea42b90afb 100644 --- a/api/core/moderation/moderation_factory.py +++ b/api/core/moderation/moderation_factory.py @@ -1,9 +1,11 @@ from core.extension.extensible import ExtensionModule -from core.moderation.base import Moderation from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: - @staticmethod - def get(name: str) -> type[Moderation]: - return code_based_extension.extension_class(ExtensionModule.MODERATION, name) + + def __init__(self, name: str): + self.__moderation_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + + def validate_config(self, config: dict) -> None: + self.__moderation_class.validate_config(config) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index d5b907a699a10e..5087e03057c345 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -365,7 +365,7 @@ def is_moderation_valid(config): type = config["sensitive_word_avoidance"]["type"] config = config["sensitive_word_avoidance"]["configs"] - ModerationFactory.get(type).validate_config(config) + ModerationFactory(type).validate_config(config) @staticmethod def is_dataset_query_variable_valid(config: dict, mode: str) -> None: From f1905e0bde92f34e50c9656e010ae0c2e66bffff Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 31 Oct 2023 23:18:55 +0800 Subject: [PATCH 11/57] feat: sort builtin extensions --- api/controllers/console/extension.py | 5 ++++- api/core/extension/extensible.py | 17 ++++++++++++++--- .../external_data_tool/api_based/__builtin__ | 1 + api/core/moderation/__init__.py | 4 ---- api/core/moderation/api_based/__builtin__ | 1 + api/core/moderation/keywords/__builtin__ | 1 + api/core/moderation/openai/__builtin__ | 1 + api/services/code_based_extension_service.py | 11 +++++++++-- 8 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 api/core/moderation/api_based/__builtin__ create mode 100644 api/core/moderation/keywords/__builtin__ create mode 100644 api/core/moderation/openai/__builtin__ diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 988ff6b54de752..30d38f6bf466f3 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -21,7 +21,10 @@ def get(self): parser.add_argument('module', type=str, required=True, location='args') args = parser.parse_args() - return CodeBasedExtensionService.get_code_based_extension(args['module']) + return { + 'module': args['module'], + 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) + } class APIBasedExtensionAPI(Resource): diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index d6a3822f3e147c..c8fd744858a643 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -3,6 +3,7 @@ import json import logging import os +from collections import OrderedDict from typing import Any from pydantic import BaseModel @@ -19,6 +20,7 @@ class ModuleExtension(BaseModel): label: dict = {} form_schema: list = [] builtin: bool = True + position: int = None class Extensible: @@ -43,9 +45,15 @@ def scan_extensions(cls): # is builtin extension, builtin extension # in the front-end page and business logic, there are special treatments. builtin = False + position = None if '__builtin__' in file_names: builtin = True + builtin_file_path = os.path.join(subdir_path, '__builtin__') + if os.path.exists(builtin_file_path): + with open(builtin_file_path, 'r') as f: + position = int(f.read().strip()) + if (extension_name + '.py') not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue @@ -75,7 +83,6 @@ def scan_extensions(cls): json_path = os.path.join(subdir_path, 'schema.json') json_data = {} if os.path.exists(json_path): - builtin = False with open(json_path, 'r') as f: json_data = json.load(f) @@ -84,7 +91,11 @@ def scan_extensions(cls): name=extension_name, label=json_data.get('label', {}), form_schema=json_data.get('form_schema', []), - builtin=builtin + builtin=builtin, + position=position ) - return extensions + sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) + sorted_extensions = OrderedDict(sorted_items) + + return sorted_extensions diff --git a/api/core/external_data_tool/api_based/__builtin__ b/api/core/external_data_tool/api_based/__builtin__ index e69de29bb2d1d6..56a6051ca2b02b 100644 --- a/api/core/external_data_tool/api_based/__builtin__ +++ b/api/core/external_data_tool/api_based/__builtin__ @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/api/core/moderation/__init__.py b/api/core/moderation/__init__.py index 727a8a1e99897c..e69de29bb2d1d6 100644 --- a/api/core/moderation/__init__.py +++ b/api/core/moderation/__init__.py @@ -1,4 +0,0 @@ -from core.moderation.openai.openai import OpenAIModeration -from core.moderation.keywords.keywords import KeywordsModeration -from core.moderation.api_based.api_based import ApiBasedModeration -from core.moderation.cloud_service.cloud_service import CloudServiceModeration \ No newline at end of file diff --git a/api/core/moderation/api_based/__builtin__ b/api/core/moderation/api_based/__builtin__ new file mode 100644 index 00000000000000..e440e5c8425869 --- /dev/null +++ b/api/core/moderation/api_based/__builtin__ @@ -0,0 +1 @@ +3 \ No newline at end of file diff --git a/api/core/moderation/keywords/__builtin__ b/api/core/moderation/keywords/__builtin__ new file mode 100644 index 00000000000000..d8263ee9860594 --- /dev/null +++ b/api/core/moderation/keywords/__builtin__ @@ -0,0 +1 @@ +2 \ No newline at end of file diff --git a/api/core/moderation/openai/__builtin__ b/api/core/moderation/openai/__builtin__ new file mode 100644 index 00000000000000..56a6051ca2b02b --- /dev/null +++ b/api/core/moderation/openai/__builtin__ @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index b9737eb6272951..ea804e590ccf8f 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -1,7 +1,14 @@ -from core.extension.extensible import Extensible +from extensions.ext_code_based_extension import code_based_extension + class CodeBasedExtensionService: @staticmethod def get_code_based_extension(module: str) -> list[dict]: - return Extensible.get_extensions().get(module, []) \ No newline at end of file + module_extensions = code_based_extension.module_extensions(module) + return [{ + 'name': module_extension.name, + 'label': module_extension.label, + 'form_schema': module_extension.form_schema, + 'builtin': module_extension.builtin + } for module_extension in module_extensions] From f4497bd3eda4d3118633467ec0c6be2d9ec1dff7 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 14:43:16 +0800 Subject: [PATCH 12/57] feat: refactor extensions --- .../api_based_extension_requestor.py | 2 + api/core/extension/extensible.py | 31 ++++++++--- .../external_data_tool/api_based/api_based.py | 52 +++++++++++++++++-- api/core/external_data_tool/base.py | 40 ++++++++++++-- .../external_data_tool_factory.py | 12 +++++ api/core/moderation/api_based/api_based.py | 36 +++++++++++-- api/core/moderation/base.py | 41 ++++++++++----- .../moderation/cloud_service/cloud_service.py | 12 ++++- api/core/moderation/keywords/keywords.py | 27 ++++++++-- api/core/moderation/moderation_factory.py | 7 +-- api/core/moderation/openai/openai.py | 23 ++++++-- 11 files changed, 242 insertions(+), 41 deletions(-) create mode 100644 api/core/extension/api_based_extension_requestor.py create mode 100644 api/core/external_data_tool/external_data_tool_factory.py diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py new file mode 100644 index 00000000000000..0d80d15368d41c --- /dev/null +++ b/api/core/extension/api_based_extension_requestor.py @@ -0,0 +1,2 @@ +class APIBasedExtensionRequestor: + pass diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c8fd744858a643..46c9f776b6a34a 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -4,10 +4,12 @@ import logging import os from collections import OrderedDict -from typing import Any +from typing import Any, Optional from pydantic import BaseModel +from extensions.ext_code_based_extension import code_based_extension + class ExtensionModule(enum.Enum): MODERATION = 'moderation' @@ -17,13 +19,30 @@ class ExtensionModule(enum.Enum): class ModuleExtension(BaseModel): extension_class: Any name: str - label: dict = {} - form_schema: list = [] + label: Optional[dict] = None + form_schema: Optional[list] = None builtin: bool = True - position: int = None + position: Optional[int] = None class Extensible: + module: ExtensionModule + + name: str + tenant_id: str + config: Optional[dict] = None + + def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + self.tenant_id = tenant_id + self.config = config + + @classmethod + def validate_form_schema(cls, config: dict) -> None: + module_extension = code_based_extension.module_extension(cls.module, cls.name) + form_schema = module_extension.form_schema + + # TODO validate form_schema + @classmethod def scan_extensions(cls): extensions = {} @@ -89,8 +108,8 @@ def scan_extensions(cls): extensions[extension_name] = ModuleExtension( extension_class=extension_class, name=extension_name, - label=json_data.get('label', {}), - form_schema=json_data.get('form_schema', []), + label=json_data.get('label'), + form_schema=json_data.get('form_schema'), builtin=builtin, position=position ) diff --git a/api/core/external_data_tool/api_based/api_based.py b/api/core/external_data_tool/api_based/api_based.py index 0e98b4417c11a7..6cf6a9e6918cce 100644 --- a/api/core/external_data_tool/api_based/api_based.py +++ b/api/core/external_data_tool/api_based/api_based.py @@ -1,13 +1,57 @@ +from typing import Optional + from core.external_data_tool.base import ExternalDataTool +from extensions.ext_database import db +from models.api_based_extension import APIBasedExtension class ApiBasedExternalDataTool(ExternalDataTool): - type = "api_based" + name: str = "api" @classmethod - def validate_config(self, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_config(tenant_id, config) + + # own validation logic api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - - # todo + + # get api_based_extension + api_based_extension = db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() + + if not api_based_extension: + raise ValueError("api_based_extension_id is invalid") + + def query(self, inputs: dict, query: Optional[str] = None) -> str: + """ + Query the external data tool. + + :param inputs: user inputs + :param query: the query of chat app + :return: the tool query result + """ + # get params from config + api_based_extension_id = self.config.get("api_based_extension_id") + + # get api_based_extension + api_based_extension = db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == self.tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() + + if not api_based_extension: + raise ValueError("api_based_extension_id is invalid") + + # todo request api + diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index beb2ae85cd7fd2..288c2beae3034e 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,5 +1,39 @@ -from core.extension.extensible import Extensible +from abc import abstractmethod, ABC +from typing import Optional +from core.extension.extensible import Extensible, ExtensionModule -class ExternalDataTool(Extensible): - pass + +class ExternalDataTool(Extensible, ABC): + """ + The base class of external data tool. + """ + + module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL + + def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + super().__init__(tenant_id, config) + + @classmethod + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_form_schema(config) + + # implement your own validation logic here + + @abstractmethod + def query(self, inputs: dict, query: Optional[str] = None) -> str: + """ + Query the external data tool. + + :param inputs: user inputs + :param query: the query of chat app + :return: the tool query result + """ + raise NotImplementedError diff --git a/api/core/external_data_tool/external_data_tool_factory.py b/api/core/external_data_tool/external_data_tool_factory.py new file mode 100644 index 00000000000000..239e566f53d219 --- /dev/null +++ b/api/core/external_data_tool/external_data_tool_factory.py @@ -0,0 +1,12 @@ +from core.extension.extensible import ExtensionModule +from extensions.ext_code_based_extension import code_based_extension + + +class ExternalDataToolFactory: + + def __init__(self, name: str, tenant_id: str, config: dict) -> None: + self.__extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) + self.__extension_instance = self.__extension_class(tenant_id, config) + + def validate_config(self, config: dict) -> None: + self.__extension_class.validate_config(config) diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api_based/api_based.py index f11d0f5dfdd7a7..027a584560fdf7 100644 --- a/api/core/moderation/api_based/api_based.py +++ b/api/core/moderation/api_based/api_based.py @@ -1,13 +1,43 @@ +from typing import Optional + from core.moderation.base import Moderation +from extensions.ext_database import db +from models.api_based_extension import APIBasedExtension class ApiBasedModeration(Moderation): - type = "api_based" + name: str = "api" @classmethod - def validate_config(cls, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_config(tenant_id, config) + cls._validate_inputs_and_outputs_config(config, False) + api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") + + # get api_based_extension + api_based_extension = db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() + + if not api_based_extension: + raise ValueError("api_based_extension_id is invalid") + + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + pass + + def moderation_for_outputs(self, text: str): + pass + + - cls._validate_inputs_and_outputs_config(config, False) \ No newline at end of file diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 60f2466840c197..2cf697b041cb49 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,21 +1,38 @@ -from abc import abstractclassmethod +from abc import ABC, abstractmethod +from typing import Optional -from core.extension.extensible import Extensible +from core.extension.extensible import Extensible, ExtensionModule -class Moderation(Extensible): +class Moderation(Extensible, ABC): + """ + The base class of moderation. + """ + module: ExtensionModule = ExtensionModule.MODERATION - @abstractclassmethod - def validate_config(self, config: dict) -> None: - pass + def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + super().__init__(tenant_id, config) - @abstractclassmethod - def moderation_for_inputs(self, config: dict): - pass + @classmethod + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_form_schema(config) + + # implement your own validation logic here + + @abstractmethod + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + raise NotImplementedError - @abstractclassmethod - def moderation_for_outputs(self, config: dict): - pass + @abstractmethod + def moderation_for_outputs(self, text: str): + raise NotImplementedError @classmethod def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: diff --git a/api/core/moderation/cloud_service/cloud_service.py b/api/core/moderation/cloud_service/cloud_service.py index c404ab27a92394..ced2b118b333fa 100644 --- a/api/core/moderation/cloud_service/cloud_service.py +++ b/api/core/moderation/cloud_service/cloud_service.py @@ -1,5 +1,13 @@ +from typing import Optional + from core.moderation.base import Moderation -from core.extension.extensible import Extensible + class CloudServiceModeration(Moderation): - type = "cloud_service" \ No newline at end of file + name: str = "cloud_service" + + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + pass + + def moderation_for_outputs(self, text: str): + pass \ No newline at end of file diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 7ce4f9542c9db5..cc9f4feca98381 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,13 +1,30 @@ +from typing import Optional + from core.moderation.base import Moderation + class KeywordsModeration(Moderation): - type = "keywords" + name: str = "keywords" @classmethod - def validate_config(cls, config): + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_config(tenant_id, config) + cls._validate_inputs_and_outputs_config(config, True) + keywords = config.get("keywords") if not keywords: raise ValueError("keywords is required") - - cls._validate_inputs_and_outputs_config(config, True) - \ No newline at end of file + + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + pass + + def moderation_for_outputs(self, text: str): + pass + diff --git a/api/core/moderation/moderation_factory.py b/api/core/moderation/moderation_factory.py index 6b22ea42b90afb..d9bc7a395c020d 100644 --- a/api/core/moderation/moderation_factory.py +++ b/api/core/moderation/moderation_factory.py @@ -4,8 +4,9 @@ class ModerationFactory: - def __init__(self, name: str): - self.__moderation_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + def __init__(self, name: str, tenant_id: str, config: dict) -> None: + self.__extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + self.__extension_instance = self.__extension_class(tenant_id, config) def validate_config(self, config: dict) -> None: - self.__moderation_class.validate_config(config) + self.__extension_class.validate_config(config) diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index d21b541bc777f3..b0b2f3c9319a43 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -1,8 +1,25 @@ +from typing import Optional + from core.moderation.base import Moderation + class OpenAIModeration(Moderation): - type = "openai" + name: str = "openai_moderation" @classmethod - def validate_config(cls, config: dict): - cls._validate_inputs_and_outputs_config(config, True) \ No newline at end of file + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + super().validate_config(tenant_id, config) + cls._validate_inputs_and_outputs_config(config, True) + + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + pass + + def moderation_for_outputs(self, text: str): + pass \ No newline at end of file From 05c065a39cfb6ea2e12271ffeeb25ac1d2c9d6b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 14:49:44 +0800 Subject: [PATCH 13/57] feat: refactor api based extensions --- api/core/external_data_tool/{api_based => api}/__builtin__ | 0 api/core/external_data_tool/{api_based => api}/__init__.py | 0 .../external_data_tool/{api_based/api_based.py => api/api.py} | 2 +- .../{external_data_tool_factory.py => factory.py} | 0 api/core/moderation/{api_based => api}/__builtin__ | 0 api/core/moderation/{api_based => api}/__init__.py | 0 api/core/moderation/{api_based/api_based.py => api/api.py} | 2 +- api/core/moderation/{moderation_factory.py => factory.py} | 0 api/services/app_model_config_service.py | 2 +- 9 files changed, 3 insertions(+), 3 deletions(-) rename api/core/external_data_tool/{api_based => api}/__builtin__ (100%) rename api/core/external_data_tool/{api_based => api}/__init__.py (100%) rename api/core/external_data_tool/{api_based/api_based.py => api/api.py} (97%) rename api/core/external_data_tool/{external_data_tool_factory.py => factory.py} (100%) rename api/core/moderation/{api_based => api}/__builtin__ (100%) rename api/core/moderation/{api_based => api}/__init__.py (100%) rename api/core/moderation/{api_based/api_based.py => api/api.py} (97%) rename api/core/moderation/{moderation_factory.py => factory.py} (100%) diff --git a/api/core/external_data_tool/api_based/__builtin__ b/api/core/external_data_tool/api/__builtin__ similarity index 100% rename from api/core/external_data_tool/api_based/__builtin__ rename to api/core/external_data_tool/api/__builtin__ diff --git a/api/core/external_data_tool/api_based/__init__.py b/api/core/external_data_tool/api/__init__.py similarity index 100% rename from api/core/external_data_tool/api_based/__init__.py rename to api/core/external_data_tool/api/__init__.py diff --git a/api/core/external_data_tool/api_based/api_based.py b/api/core/external_data_tool/api/api.py similarity index 97% rename from api/core/external_data_tool/api_based/api_based.py rename to api/core/external_data_tool/api/api.py index 6cf6a9e6918cce..d5d7cd07626b10 100644 --- a/api/core/external_data_tool/api_based/api_based.py +++ b/api/core/external_data_tool/api/api.py @@ -5,7 +5,7 @@ from models.api_based_extension import APIBasedExtension -class ApiBasedExternalDataTool(ExternalDataTool): +class ApiExternalDataTool(ExternalDataTool): name: str = "api" @classmethod diff --git a/api/core/external_data_tool/external_data_tool_factory.py b/api/core/external_data_tool/factory.py similarity index 100% rename from api/core/external_data_tool/external_data_tool_factory.py rename to api/core/external_data_tool/factory.py diff --git a/api/core/moderation/api_based/__builtin__ b/api/core/moderation/api/__builtin__ similarity index 100% rename from api/core/moderation/api_based/__builtin__ rename to api/core/moderation/api/__builtin__ diff --git a/api/core/moderation/api_based/__init__.py b/api/core/moderation/api/__init__.py similarity index 100% rename from api/core/moderation/api_based/__init__.py rename to api/core/moderation/api/__init__.py diff --git a/api/core/moderation/api_based/api_based.py b/api/core/moderation/api/api.py similarity index 97% rename from api/core/moderation/api_based/api_based.py rename to api/core/moderation/api/api.py index 027a584560fdf7..e8ee9f4070f707 100644 --- a/api/core/moderation/api_based/api_based.py +++ b/api/core/moderation/api/api.py @@ -5,7 +5,7 @@ from models.api_based_extension import APIBasedExtension -class ApiBasedModeration(Moderation): +class ApiModeration(Moderation): name: str = "api" @classmethod diff --git a/api/core/moderation/moderation_factory.py b/api/core/moderation/factory.py similarity index 100% rename from api/core/moderation/moderation_factory.py rename to api/core/moderation/factory.py diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 5087e03057c345..925bce3980c509 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,7 +1,7 @@ import re import uuid -from core.moderation.moderation_factory import ModerationFactory +from core.moderation.factory import ModerationFactory from core.prompt.prompt_transform import AppMode from core.agent.agent_executor import PlanningStrategy from core.model_providers.model_provider_factory import ModelProviderFactory From d526d861adbea6210b341149602f8098fb632fd9 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 16:02:39 +0800 Subject: [PATCH 14/57] feat: add api request support --- .../api_based_extension_requestor.py | 50 ++++++++++++++++++- api/core/external_data_tool/api/api.py | 27 +++++++++- api/core/external_data_tool/base.py | 7 ++- api/core/external_data_tool/factory.py | 9 +++- api/models/api_based_extension.py | 7 +++ 5 files changed, 94 insertions(+), 6 deletions(-) diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 0d80d15368d41c..dd89015376db84 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,2 +1,50 @@ +import requests + +from models.api_based_extension import APIBasedExtensionPoint + + class APIBasedExtensionRequestor: - pass + timeout: (int, int) = (5, 60) + """timeout for request connect and read""" + + def __init__(self, api_endpoint: str, api_key: str) -> None: + self.api_endpoint = api_endpoint + self.api_key = api_key + + def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: + """ + Request the api. + + :param point: the api point + :param params: the request params + :return: the response json + """ + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(self.api_key) + } + + url = self.api_endpoint + + try: + response = requests.request( + method='POST', + url=url, + json={ + 'point': point.value, + 'params': params + }, + headers=headers, + timeout=self.timeout + ) + + # TODO proxy support for security + except requests.exceptions.Timeout: + raise ValueError("request timeout") + except requests.exceptions.ConnectionError: + raise ValueError("request connection error") + + if response.status_code != 200: + raise ValueError("request error, status_code: {}, content: {}".format(response.status_code, response.content)) + + return response.json() diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index d5d7cd07626b10..ecfeada94e3bd1 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,8 +1,10 @@ from typing import Optional +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.external_data_tool.base import ExternalDataTool +from core.helper import encrypter from extensions.ext_database import db -from models.api_based_extension import APIBasedExtension +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class ApiExternalDataTool(ExternalDataTool): @@ -53,5 +55,26 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - # todo request api + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=self.tenant_id, + token=api_based_extension.api_key + ) + # request api + requestor = APIBasedExtensionRequestor( + api_endpoint=api_based_extension.api_endpoint, + api_key=api_based_extension.api_key + ) + + response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ + 'app_id': self.app_id, + 'tool_variable': self.variable, + 'inputs': inputs, + 'query': query + }) + + if 'result' not in response_json: + raise ValueError("result not found in response") + + return response_json['result'] diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 288c2beae3034e..7c41c286cdef23 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -11,8 +11,13 @@ class ExternalDataTool(Extensible, ABC): module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL - def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + app_id: str + variable: str + + def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: super().__init__(tenant_id, config) + self.app_id = app_id + self.variable = variable @classmethod def validate_config(cls, tenant_id: str, config: dict) -> None: diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 239e566f53d219..d544d3cf3e25ae 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -4,9 +4,14 @@ class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, config: dict) -> None: + def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: self.__extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - self.__extension_instance = self.__extension_class(tenant_id, config) + self.__extension_instance = self.__extension_class( + tenant_id=tenant_id, + app_id=app_id, + variable=variable, + config=config + ) def validate_config(self, config: dict) -> None: self.__extension_class.validate_config(config) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index f0bcbc7b5ae595..9468f0897ce148 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,14 @@ +import enum + from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db + +class APIBasedExtensionPoint(enum.Enum): + APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' + + class APIBasedExtension(db.Model): __tablename__ = 'api_based_extensions' __table_args__ = ( From 784d73eaff1028d38b9ff875d2b261ef8b690d91 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 16:07:20 +0800 Subject: [PATCH 15/57] feat: add comments --- api/core/external_data_tool/api/api.py | 7 ++++++- api/core/external_data_tool/base.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index ecfeada94e3bd1..221aea19e10171 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -8,7 +8,12 @@ class ApiExternalDataTool(ExternalDataTool): + """ + The api external data tool. + """ + name: str = "api" + """the unique name of external data tool""" @classmethod def validate_config(cls, tenant_id: str, config: dict) -> None: @@ -64,7 +69,7 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: # request api requestor = APIBasedExtensionRequestor( api_endpoint=api_based_extension.api_endpoint, - api_key=api_based_extension.api_key + api_key=api_key ) response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 7c41c286cdef23..195ee3fa9c6f7f 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -12,7 +12,9 @@ class ExternalDataTool(Extensible, ABC): module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL app_id: str + """the id of app""" variable: str + """the tool variable name of app tool""" def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: super().__init__(tenant_id, config) From abe576aea319d81a7cbec479d73484e1594aaae2 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 16:43:30 +0800 Subject: [PATCH 16/57] feat: add is_external_data_tools_valid --- api/core/external_data_tool/factory.py | 30 ++++++++-- api/core/moderation/factory.py | 26 +++++++-- api/services/app_model_config_service.py | 73 ++++++++++++++++++------ 3 files changed, 102 insertions(+), 27 deletions(-) diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index d544d3cf3e25ae..5b90d8b63cc511 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -5,13 +7,33 @@ class ExternalDataToolFactory: def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: - self.__extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - self.__extension_instance = self.__extension_class( + extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) + self.__extension_instance = extension_class( tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) - def validate_config(self, config: dict) -> None: - self.__extension_class.validate_config(config) + @classmethod + def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param name: the name of external data tool + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) + extension_class.validate_config(tenant_id, config) + + def query(self, inputs: dict, query: Optional[str] = None) -> str: + """ + Query the external data tool. + + :param inputs: user inputs + :param query: the query of chat app + :return: the tool query result + """ + return self.__extension_instance.query(inputs, query) diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index d9bc7a395c020d..33293090592188 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -5,8 +7,24 @@ class ModerationFactory: def __init__(self, name: str, tenant_id: str, config: dict) -> None: - self.__extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - self.__extension_instance = self.__extension_class(tenant_id, config) + extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + self.__extension_instance = extension_class(tenant_id, config) + + @classmethod + def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param name: the name of extension + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + extension_class.validate_config(tenant_id, config) + + def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + return self.__extension_instance.moderation_for_inputs(inputs, query) - def validate_config(self, config: dict) -> None: - self.__extension_class.validate_config(config) + def moderation_for_outputs(self, text: str): + return self.__extension_instance.moderation_for_outputs(text) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 925bce3980c509..1d003a27904fb5 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,6 +1,7 @@ import re import uuid +from core.external_data_tool.factory import ExternalDataToolFactory from core.moderation.factory import ModerationFactory from core.prompt.prompt_transform import AppMode from core.agent.agent_executor import PlanningStrategy @@ -14,8 +15,8 @@ class AppModelConfigService: - @staticmethod - def is_dataset_exists(account: Account, dataset_id: str) -> bool: + @classmethod + def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool: # verify if the dataset ID exists dataset = DatasetService.get_dataset(dataset_id) @@ -27,8 +28,8 @@ def is_dataset_exists(account: Account, dataset_id: str) -> bool: return True - @staticmethod - def validate_model_completion_params(cp: dict, model_name: str) -> dict: + @classmethod + def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: # 6. model.completion_params if not isinstance(cp, dict): raise ValueError("model.completion_params must be of object type") @@ -74,8 +75,8 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: return filtered_cp - @staticmethod - def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict: + @classmethod + def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict: # opening_statement if 'opening_statement' not in config or not config["opening_statement"]: config["opening_statement"] = "" @@ -187,7 +188,7 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: if 'completion_params' not in config["model"]: raise ValueError("model.completion_params is required") - config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params( + config["model"]["completion_params"] = cls.validate_model_completion_params( config["model"]["completion_params"], config["model"]["name"] ) @@ -304,17 +305,20 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: except ValueError: raise ValueError("id in dataset must be of UUID type") - if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]): + if not cls.is_dataset_exists(account, tool_item["id"]): raise ValueError("Dataset ID does not exist, please check your permission.") # dataset_query_variable - AppModelConfigService.is_dataset_query_variable_valid(config, mode) + cls.is_dataset_query_variable_valid(config, mode) # advanced prompt validation - AppModelConfigService.is_advanced_prompt_valid(config, mode) + cls.is_advanced_prompt_valid(config, mode) + + # external data tools validation + cls.is_external_data_tools_valid(tenant_id, config) # moderation validation - AppModelConfigService.is_moderation_valid(config) + cls.is_moderation_valid(tenant_id, config) # Filter out extra parameters filtered_config = { @@ -343,8 +347,8 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: return filtered_config - @staticmethod - def is_moderation_valid(config): + @classmethod + def is_moderation_valid(cls, tenant_id: str, config: dict): if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: config["sensitive_word_avoidance"] = { "enabled": False @@ -363,12 +367,43 @@ def is_moderation_valid(config): raise ValueError("sensitive_word_avoidance.type is required") type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["configs"] + config = config["sensitive_word_avoidance"]["config"] + + ModerationFactory.validate_config( + name=type, + tenant_id=tenant_id, + config=config + ) + + @classmethod + def is_external_data_tools_valid(cls, tenant_id: str, config: dict): + if 'external_data_tools' not in config or not config["external_data_tools"]: + config["external_data_tools"] = [] + + if not isinstance(config["external_data_tools"], list): + raise ValueError("external_data_tools must be of list type") + + for tool in config["external_data_tools"]: + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + + if not tool["enabled"]: + continue + + if "type" not in tool or not tool["type"]: + raise ValueError("external_data_tools[].type is required") + + type = tool["type"] + config = tool["config"] - ModerationFactory(type).validate_config(config) + ExternalDataToolFactory.validate_config( + name=type, + tenant_id=tenant_id, + config=config + ) - @staticmethod - def is_dataset_query_variable_valid(config: dict, mode: str) -> None: + @classmethod + def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: # Only check when mode is completion if mode != 'completion': return @@ -383,8 +418,8 @@ def is_dataset_query_variable_valid(config: dict, mode: str) -> None: raise ValueError("Dataset query variable is required when dataset is exist") - @staticmethod - def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: + @classmethod + def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type if 'prompt_type' not in config or not config["prompt_type"]: config["prompt_type"] = "simple" From b68b14646636b1186d5fb625e3adced71926e57e Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 17:43:17 +0800 Subject: [PATCH 17/57] fix: circle import error --- api/core/extension/extensible.py | 9 --------- api/core/extension/extension.py | 7 +++++++ api/core/external_data_tool/base.py | 5 ++--- api/core/external_data_tool/factory.py | 1 + api/core/moderation/base.py | 22 +++++++++++++++++++--- api/core/moderation/factory.py | 18 ++++++++++++++++++ 6 files changed, 47 insertions(+), 15 deletions(-) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 46c9f776b6a34a..2e879578bf7357 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -8,8 +8,6 @@ from pydantic import BaseModel -from extensions.ext_code_based_extension import code_based_extension - class ExtensionModule(enum.Enum): MODERATION = 'moderation' @@ -36,13 +34,6 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: self.tenant_id = tenant_id self.config = config - @classmethod - def validate_form_schema(cls, config: dict) -> None: - module_extension = code_based_extension.module_extension(cls.module, cls.name) - form_schema = module_extension.form_schema - - # TODO validate form_schema - @classmethod def scan_extensions(cls): extensions = {} diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 494f46d017cc3b..845484cb1a1da6 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -39,3 +39,10 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) return module_extension.extension_class + + def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: + module_extension = self.module_extension(module, extension_name) + form_schema = module_extension.form_schema + + # TODO validate form_schema + diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 195ee3fa9c6f7f..1c181ff3c56c53 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -22,6 +22,7 @@ def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[ self.variable = variable @classmethod + @abstractmethod def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. @@ -30,9 +31,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_form_schema(config) - - # implement your own validation logic here + raise NotImplementedError @abstractmethod def query(self, inputs: dict, query: Optional[str] = None) -> str: diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 5b90d8b63cc511..979f243af65f61 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -25,6 +25,7 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ + code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) extension_class.validate_config(tenant_id, config) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 2cf697b041cb49..d0a94b4358e19d 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -14,6 +14,7 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: super().__init__(tenant_id, config) @classmethod + @abstractmethod def validate_config(cls, tenant_id: str, config: dict) -> None: """ Validate the incoming form config data. @@ -22,16 +23,31 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_form_schema(config) - - # implement your own validation logic here + raise NotImplementedError @abstractmethod def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + """ + Moderation for inputs. + After the user inputs, this method will be called to perform sensitive content review + on the user inputs and return the processed results. + + :param inputs: user inputs + :param query: query string (required in chat app) + :return: + """ raise NotImplementedError @abstractmethod def moderation_for_outputs(self, text: str): + """ + Moderation for outputs. + When LLM outputs content, the front end will pass the output content (may be segmented) + to this method for sensitive content review, and the output content will be shielded if the review fails. + + :param text: LLM output content + :return: + """ raise NotImplementedError @classmethod diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 33293090592188..480a30024ce8ad 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -20,11 +20,29 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ + code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) extension_class.validate_config(tenant_id, config) def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + """ + Moderation for inputs. + After the user inputs, this method will be called to perform sensitive content review + on the user inputs and return the processed results. + + :param inputs: user inputs + :param query: query string (required in chat app) + :return: + """ return self.__extension_instance.moderation_for_inputs(inputs, query) def moderation_for_outputs(self, text: str): + """ + Moderation for outputs. + When LLM outputs content, the front end will pass the output content (may be segmented) + to this method for sensitive content review, and the output content will be shielded if the review fails. + + :param text: LLM output content + :return: + """ return self.__extension_instance.moderation_for_outputs(text) \ No newline at end of file From 0a791ce4d497f30153a143e3bff802faabccb4f3 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 1 Nov 2023 18:11:10 +0800 Subject: [PATCH 18/57] feat: remove builtin extension from list --- api/services/code_based_extension_service.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index ea804e590ccf8f..38d8d112b513cb 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -9,6 +9,5 @@ def get_code_based_extension(module: str) -> list[dict]: return [{ 'name': module_extension.name, 'label': module_extension.label, - 'form_schema': module_extension.form_schema, - 'builtin': module_extension.builtin - } for module_extension in module_extensions] + 'form_schema': module_extension.form_schema + } for module_extension in module_extensions if not module_extension.builtin] From fc6ab60e4f5bf6b2c92bbff805f5c678b3aff04d Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Wed, 1 Nov 2023 21:17:44 +0800 Subject: [PATCH 19/57] add implements. --- api/core/completion.py | 47 ++++++++++++++---------- api/core/moderation/api/api.py | 1 - api/core/moderation/base.py | 3 ++ api/core/moderation/factory.py | 2 + api/core/moderation/keywords/keywords.py | 37 +++++++++++++++---- api/core/moderation/openai/openai.py | 46 +++++++++++++++++++++-- api/services/app_model_config_service.py | 2 +- 7 files changed, 106 insertions(+), 32 deletions(-) diff --git a/api/core/completion.py b/api/core/completion.py index 57e18199271ccb..062c25df45ef13 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -6,7 +6,6 @@ from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler -from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ @@ -18,6 +17,8 @@ from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from models.model import App, AppModelConfig, Account, Conversation, EndUser +from core.moderation.base import ModerationException +from core.moderation.factory import ModerationFactory class Completion: @@ -78,24 +79,22 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer try: # parse sensitive_word_avoidance_chain chain_callback = MainChainGatherCallbackHandler(conversation_message_task) - sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain( - final_model_instance, [chain_callback]) - if sensitive_word_avoidance_chain: - try: - query = sensitive_word_avoidance_chain.run(query) - except SensitiveWordAvoidanceError as ex: - cls.run_final_llm( - model_instance=final_model_instance, - mode=app.mode, - app_model_config=app_model_config, - query=query, - inputs=inputs, - agent_execute_result=None, - conversation_message_task=conversation_message_task, - memory=memory, - fake_response=ex.message - ) - return + + try: + cls.moderation_for_inputs(app.tenant_id, app_model_config, inputs, query) + except ModerationException as e: + cls.run_final_llm( + model_instance=final_model_instance, + mode=app.mode, + app_model_config=app_model_config, + query=query, + inputs=inputs, + agent_execute_result=None, + conversation_message_task=conversation_message_task, + memory=memory, + fake_response=str(e) + ) + return # get agent executor agent_executor = orchestrator_rule_parser.to_agent_executor( @@ -142,6 +141,16 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer logging.warning(f'ChunkedEncodingError: {e}') conversation_message_task.end() return + + @classmethod + def moderation_for_inputs(cls, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): + if not app_model_config.sensitive_word_avoidance_dict['enabled']: + return + + type = app_model_config.sensitive_word_avoidance_dict['type'] + + moderation = ModerationFactory(type, tenant_id, app_model_config.sensitive_word_avoidance_dict['configs']) + moderation.moderation_for_inputs(inputs, query) @classmethod def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index e8ee9f4070f707..6df772c2f66486 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -17,7 +17,6 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_config(tenant_id, config) cls._validate_inputs_and_outputs_config(config, False) api_based_extension_id = config.get("api_based_extension_id") diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index d0a94b4358e19d..d6508ba98337b4 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -76,3 +76,6 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r if outputs_configs_enabled and not outputs_configs.get("preset_response"): raise ValueError("outputs_configs.preset_response is required") + +class ModerationException(Exception): + pass \ No newline at end of file diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 480a30024ce8ad..559c72662941d8 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,10 +1,12 @@ from typing import Optional from core.extension.extensible import ExtensionModule +from core.moderation.base import Moderation from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: + __extension_instance: Moderation def __init__(self, name: str, tenant_id: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index cc9f4feca98381..6320d803a69747 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ from typing import Optional -from core.moderation.base import Moderation +from core.moderation.base import Moderation, ModerationException class KeywordsModeration(Moderation): @@ -15,16 +15,39 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_config(tenant_id, config) cls._validate_inputs_and_outputs_config(config, True) - keywords = config.get("keywords") - if not keywords: + if not config.get("keywords"): raise ValueError("keywords is required") def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - pass + if not self.config['inputs_configs']['enabled']: + return - def moderation_for_outputs(self, text: str): - pass + if query: + inputs['query__'] = query + + keywords_list = self.config['keywords'].split('\n') + preset_response = self.config['inputs_configs']['preset_response'] + self._is_violated(inputs, preset_response, keywords_list) + + def moderation_for_outputs(self, text: str): + if not self.config['outputs_configs']['enabled']: + return + + keywords_list = self.config['keywords'].split('\n') + preset_response = self.config['outputs_configs']['preset_response'] + + self._is_violated({'text': text}, preset_response, keywords_list) + + def _is_violated(self, inputs: dict, preset_response: str, keywords_list: list): + for value in inputs.values(): + if self._check_keywords_in_text(keywords_list, value): + raise ModerationException(preset_response) + + def _check_keywords_in_text(self, keywords_list, text): + for keyword in keywords_list: + if keyword in text: + return True + return False diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index b0b2f3c9319a43..857ef5b9c35b92 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -1,6 +1,11 @@ +import openai +import json from typing import Optional -from core.moderation.base import Moderation +from core.helper.encrypter import decrypt_token +from core.moderation.base import Moderation, ModerationException +from extensions.ext_database import db +from models.provider import Provider class OpenAIModeration(Moderation): @@ -15,11 +20,44 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_config(tenant_id, config) cls._validate_inputs_and_outputs_config(config, True) def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - pass + if not self.config['inputs_configs']['enabled']: + return + + preset_response = self.config['inputs_configs']['preset_response'] + if query: + inputs['query__'] = query + + self._is_violated(inputs, preset_response) def moderation_for_outputs(self, text: str): - pass \ No newline at end of file + if not self.config['outputs_configs']['enabled']: + return + + preset_response = self.config['inputs_configs']['preset_response'] + + self._is_violated({ 'text': text }, preset_response) + + def _is_violated(self, inputs: dict, preset_response: str): + + openai_api_key = self._get_openai_api_key() + moderation_result = openai.Moderation.create(input=list(inputs.values()), api_key=openai_api_key) + + for result in moderation_result.results: + if result['flagged']: + raise ModerationException(preset_response) + + def _get_openai_api_key(self) -> str: + provider = db.session.query(Provider) \ + .filter_by(tenant_id=self.tenant_id) \ + .filter_by(provider_name="openai") \ + .first() + + if not provider: + raise ValueError("openai provider is not configured") + + encrypted_config = json.loads(provider.encrypted_config) + + return decrypt_token(self.tenant_id, encrypted_config['openai_api_key']) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 1d003a27904fb5..580a7a1f55bc58 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -367,7 +367,7 @@ def is_moderation_valid(cls, tenant_id: str, config: dict): raise ValueError("sensitive_word_avoidance.type is required") type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] + config = config["sensitive_word_avoidance"]["configs"] ModerationFactory.validate_config( name=type, From 305f49e4bbc025a82325f8ad332dd0461a49b58e Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 2 Nov 2023 06:44:09 +0800 Subject: [PATCH 20/57] feat: add external_data_tool query in generate --- .../chain/sensitive_word_avoidance_chain.py | 92 ------------------- api/core/completion.py | 75 ++++++++++++++- .../api_based_extension_requestor.py | 15 ++- api/core/orchestrator_rule_parser.py | 47 ---------- api/fields/app_fields.py | 1 + ...e_add_external_data_tools_in_app_model_.py | 34 +++++++ api/models/model.py | 42 +++------ api/services/app_model_config_service.py | 1 + 8 files changed, 133 insertions(+), 174 deletions(-) delete mode 100644 api/core/chain/sensitive_word_avoidance_chain.py create mode 100644 api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py diff --git a/api/core/chain/sensitive_word_avoidance_chain.py b/api/core/chain/sensitive_word_avoidance_chain.py deleted file mode 100644 index 62d58542751b2f..00000000000000 --- a/api/core/chain/sensitive_word_avoidance_chain.py +++ /dev/null @@ -1,92 +0,0 @@ -import enum -import logging -from typing import List, Dict, Optional, Any - -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain -from pydantic import BaseModel - -from core.model_providers.error import LLMBadRequestError -from core.model_providers.model_factory import ModelFactory -from core.model_providers.models.llm.base import BaseLLM -from core.model_providers.models.moderation import openai_moderation - - -class SensitiveWordAvoidanceRule(BaseModel): - class Type(enum.Enum): - MODERATION = "moderation" - KEYWORDS = "keywords" - - type: Type - canned_response: str = 'Your content violates our usage policy. Please revise and try again.' - extra_params: dict = {} - - -class SensitiveWordAvoidanceChain(Chain): - input_key: str = "input" #: :meta private: - output_key: str = "output" #: :meta private: - - model_instance: BaseLLM - sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule - - @property - def _chain_type(self) -> str: - return "sensitive_word_avoidance_chain" - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return output key. - - :meta private: - """ - return [self.output_key] - - def _check_sensitive_word(self, text: str) -> bool: - for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []): - if word in text: - return False - return True - - def _check_moderation(self, text: str) -> bool: - moderation_model_instance = ModelFactory.get_moderation_model( - tenant_id=self.model_instance.model_provider.provider.tenant_id, - model_provider_name='openai', - model_name=openai_moderation.DEFAULT_MODEL - ) - - try: - return moderation_model_instance.run(text=text) - except Exception as ex: - logging.exception(ex) - raise LLMBadRequestError('Rate limit exceeded, please try again later.') - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - text = inputs[self.input_key] - - if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS: - result = self._check_sensitive_word(text) - else: - result = self._check_moderation(text) - - if not result: - raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response) - - return {self.output_key: text} - - -class SensitiveWordAvoidanceError(Exception): - def __init__(self, message): - super().__init__(message) - self.message = message diff --git a/api/core/completion.py b/api/core/completion.py index 062c25df45ef13..21e2fd1ff1c1b0 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,12 +1,16 @@ +import concurrent import logging -from typing import Optional, List, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, List, Union, Tuple +from flask import current_app, Flask from requests.exceptions import ChunkedEncodingError from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.external_data_tool.factory import ExternalDataToolFactory from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ ReadOnlyConversationTokenDBBufferSharedMemory @@ -77,10 +81,10 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer ) try: - # parse sensitive_word_avoidance_chain chain_callback = MainChainGatherCallbackHandler(conversation_message_task) try: + # process sensitive_word_avoidance cls.moderation_for_inputs(app.tenant_id, app_model_config, inputs, query) except ModerationException as e: cls.run_final_llm( @@ -96,6 +100,17 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer ) return + # fill in variable inputs from external data tools if exists + external_data_tools = app_model_config.external_data_tools_list + if external_data_tools: + inputs = cls.fill_in_inputs_from_external_data_tools( + tenant_id=app.tenant_id, + app_id=app.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + # get agent executor agent_executor = orchestrator_rule_parser.to_agent_executor( conversation_message_task=conversation_message_task, @@ -151,7 +166,61 @@ def moderation_for_inputs(cls, tenant_id: str, app_model_config: AppModelConfig, moderation = ModerationFactory(type, tenant_id, app_model_config.sensitive_word_avoidance_dict['configs']) moderation.moderation_for_inputs(inputs, query) - + + @classmethod + def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict], + inputs: dict, query: str) -> dict: + """ + Fill in variable inputs from external data tools if exists. + + :param tenant_id: workspace id + :param app_id: app id + :param external_data_tools: external data tools configs + :param inputs: the inputs + :param query: the query + :return: the filled inputs + """ + results = {} + with ThreadPoolExecutor() as executor: + futures = {executor.submit( + cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool, inputs, query + ): tool for tool in external_data_tools} + for future in concurrent.futures.as_completed(futures): + tool_variable, result = future.result() + if tool_variable is not None: + results[tool_variable] = result + + inputs.update(results) + return inputs + + @classmethod + def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict, + inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]: + with flask_app.app_context(): + enabled = external_data_tool.get("enabled") + if not enabled: + return None, None + + tool_variable = external_data_tool.get("variable") + tool_type = external_data_tool.get("type") + tool_config = external_data_tool.get("config") + + external_data_tool_factory = ExternalDataToolFactory( + name=tool_type, + tenant_id=tenant_id, + app_id=app_id, + variable=tool_variable, + config=tool_config + ) + + # query external data tool + result = external_data_tool_factory.query( + inputs=inputs, + query=query + ) + + return tool_variable, result + @classmethod def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str: if app.mode != 'completion': diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index dd89015376db84..6e5996ab73649d 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,3 +1,5 @@ +import os + import requests from models.api_based_extension import APIBasedExtensionPoint @@ -27,6 +29,14 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: url = self.api_endpoint try: + # proxy support for security + proxies = None + if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"): + proxies = { + 'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"), + 'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"), + } + response = requests.request( method='POST', url=url, @@ -35,10 +45,9 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: 'params': params }, headers=headers, - timeout=self.timeout + timeout=self.timeout, + proxies=proxies ) - - # TODO proxy support for security except requests.exceptions.Timeout: raise ValueError("request timeout") except requests.exceptions.ConnectionError: diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 2ba732ee3dd613..d057c160a2a4a5 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -11,7 +11,6 @@ from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule from core.conversation_message_task import ConversationMessageTask from core.model_providers.error import ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory @@ -125,52 +124,6 @@ def to_agent_executor(self, conversation_message_task: ConversationMessageTask, return chain - def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \ - -> Optional[SensitiveWordAvoidanceChain]: - """ - Convert app sensitive word avoidance config to chain - - :param model_instance: model instance - :param callbacks: callbacks for the chain - :param kwargs: - :return: - """ - sensitive_word_avoidance_rule = None - - if self.app_model_config.sensitive_word_avoidance_dict: - sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict - if sensitive_word_avoidance_config.get("enabled", False): - if sensitive_word_avoidance_config.get('type') == 'moderation': - sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule( - type=SensitiveWordAvoidanceRule.Type.MODERATION, - canned_response=sensitive_word_avoidance_config.get("canned_response") - if sensitive_word_avoidance_config.get("canned_response") - else 'Your content violates our usage policy. Please revise and try again.', - ) - else: - sensitive_words = sensitive_word_avoidance_config.get("words", "") - if sensitive_words: - sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule( - type=SensitiveWordAvoidanceRule.Type.KEYWORDS, - canned_response=sensitive_word_avoidance_config.get("canned_response") - if sensitive_word_avoidance_config.get("canned_response") - else 'Your content violates our usage policy. Please revise and try again.', - extra_params={ - 'sensitive_words': sensitive_words.split(','), - } - ) - - if sensitive_word_avoidance_rule: - return SensitiveWordAvoidanceChain( - model_instance=model_instance, - sensitive_word_avoidance_rule=sensitive_word_avoidance_rule, - output_key="sensitive_word_avoidance_output", - callbacks=callbacks, - **kwargs - ) - - return None - def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]: """ Convert app agent tool configs to tools diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index fccfa5df306f5d..154b10ceb6f3d3 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -23,6 +23,7 @@ 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), + 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), 'model': fields.Raw(attribute='model_dict'), 'user_input_form': fields.Raw(attribute='user_input_form_list'), 'dataset_query_variable': fields.String, diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py new file mode 100644 index 00000000000000..9c8dcd4768f0fc --- /dev/null +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -0,0 +1,34 @@ +"""add external_data_tools in app model config + +Revision ID: a9836e3baeee +Revises: 968fff4c0ab9 +Create Date: 2023-11-02 04:04:57.609485 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'a9836e3baeee' +down_revision = '968fff4c0ab9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('sessions') + op.drop_table('third_party_provider_applies') + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('external_data_tools') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index d3f5c8135f1f99..d9c528ef4481ae 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -97,6 +97,7 @@ class AppModelConfig(db.Model): chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) + external_data_tools = db.Column(db.Text) @property def app(self): @@ -135,6 +136,11 @@ def sensitive_word_avoidance_dict(self) -> dict: return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ else {"enabled": False, "words": [], "canned_response": []} + @property + def external_data_tools_list(self) -> list[dict]: + return json.loads(self.external_data_tools) if self.external_data_tools \ + else [] + @property def user_input_form_list(self) -> dict: return json.loads(self.user_input_form) if self.user_input_form else [] @@ -167,6 +173,7 @@ def to_dict(self) -> dict: "retriever_resource": self.retriever_resource_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, + "external_data_tools": self.external_data_tools_list, "model": self.model_dict, "user_input_form": self.user_input_form_list, "dataset_query_variable": self.dataset_query_variable, @@ -190,6 +197,7 @@ def from_model_config_dict(self, model_config: dict): self.more_like_this = json.dumps(model_config['more_like_this']) self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ if model_config.get('sensitive_word_avoidance') else None + self.external_data_tools = json.dumps(model_config['external_data_tools']) self.model = json.dumps(model_config['model']) self.user_input_form = json.dumps(model_config['user_input_form']) self.dataset_query_variable = model_config.get('dataset_query_variable') @@ -219,6 +227,7 @@ def copy(self): speech_to_text=self.speech_to_text, more_like_this=self.more_like_this, sensitive_word_avoidance=self.sensitive_word_avoidance, + external_data_tools=self.external_data_tools, model=self.model, user_input_form=self.user_input_form, dataset_query_variable=self.dataset_query_variable, @@ -332,41 +341,16 @@ def model_config(self): override_model_configs = json.loads(self.override_model_configs) if 'model' in override_model_configs: - model_config['model'] = override_model_configs['model'] - model_config['pre_prompt'] = override_model_configs['pre_prompt'] - model_config['agent_mode'] = override_model_configs['agent_mode'] - model_config['opening_statement'] = override_model_configs['opening_statement'] - model_config['suggested_questions'] = override_model_configs['suggested_questions'] - model_config['suggested_questions_after_answer'] = override_model_configs[ - 'suggested_questions_after_answer'] \ - if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False} - model_config['speech_to_text'] = override_model_configs[ - 'speech_to_text'] \ - if 'speech_to_text' in override_model_configs else {"enabled": False} - model_config['more_like_this'] = override_model_configs['more_like_this'] \ - if 'more_like_this' in override_model_configs else {"enabled": False} - model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \ - if 'sensitive_word_avoidance' in override_model_configs \ - else {"enabled": False, "words": [], "canned_response": []} - model_config['user_input_form'] = override_model_configs['user_input_form'] + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(override_model_configs) + model_config = app_model_config.to_dict() else: model_config['configs'] = override_model_configs else: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() - model_config['configs'] = app_model_config.configs - model_config['model'] = app_model_config.model_dict - model_config['pre_prompt'] = app_model_config.pre_prompt - model_config['agent_mode'] = app_model_config.agent_mode_dict - model_config['opening_statement'] = app_model_config.opening_statement - model_config['suggested_questions'] = app_model_config.suggested_questions_list - model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict - model_config['speech_to_text'] = app_model_config.speech_to_text_dict - model_config['retriever_resource'] = app_model_config.retriever_resource_dict - model_config['more_like_this'] = app_model_config.more_like_this_dict - model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict - model_config['user_input_form'] = app_model_config.user_input_form_list + model_config = app_model_config.to_dict() model_config['model_id'] = self.model_id model_config['provider'] = self.model_provider diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 580a7a1f55bc58..9f62734b007b5d 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -329,6 +329,7 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, "retriever_resource": config["retriever_resource"], "more_like_this": config["more_like_this"], "sensitive_word_avoidance": config["sensitive_word_avoidance"], + "external_data_tools": config["external_data_tools"], "model": { "provider": config["model"]["provider"], "name": config["model"]["name"], From f6866c8d0c83bc5a1772373d8ebeee3990da687c Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 2 Nov 2023 10:11:48 +0800 Subject: [PATCH 21/57] fix: bug --- api/core/external_data_tool/api/api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 221aea19e10171..cdcf4c2ae9947f 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -24,8 +24,6 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: :param config: the form config data :return: """ - super().validate_config(tenant_id, config) - # own validation logic api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: From 343c075136d99d70b6046c35fa587c2525d6d3f8 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 2 Nov 2023 10:24:02 +0800 Subject: [PATCH 22/57] feat: optimize multi thread call api --- api/core/completion.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/api/core/completion.py b/api/core/completion.py index 21e2fd1ff1c1b0..909a955c0e1881 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,4 +1,5 @@ import concurrent +import json import logging from concurrent.futures import ThreadPoolExecutor from typing import Optional, List, Union, Tuple @@ -180,15 +181,33 @@ def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, ex :param query: the query :return: the filled inputs """ + # Group tools by type and config + grouped_tools = {} + for tool in external_data_tools: + if not tool.get("enabled"): + continue + + tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True)) + grouped_tools.setdefault(tool_key, []).append(tool) + results = {} with ThreadPoolExecutor() as executor: - futures = {executor.submit( - cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool, inputs, query - ): tool for tool in external_data_tools} + futures = {} + for tools in grouped_tools.values(): + # Only query the first tool in each group + first_tool = tools[0] + future = executor.submit( + cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool, + inputs, query + ) + for tool in tools: + futures[future] = tool + for future in concurrent.futures.as_completed(futures): - tool_variable, result = future.result() - if tool_variable is not None: - results[tool_variable] = result + tool_key, result = future.result() + if tool_key in grouped_tools: + for tool in grouped_tools[tool_key]: + results[tool['variable']] = result inputs.update(results) return inputs @@ -197,10 +216,6 @@ def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, ex def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict, inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]: with flask_app.app_context(): - enabled = external_data_tool.get("enabled") - if not enabled: - return None, None - tool_variable = external_data_tool.get("variable") tool_type = external_data_tool.get("type") tool_config = external_data_tool.get("config") @@ -219,7 +234,9 @@ def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, query=query ) - return tool_variable, result + tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True)) + + return tool_key, result @classmethod def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str: From 99dcb17e2525b1ad7d78b8c813f191b5954aaea6 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 2 Nov 2023 10:50:00 +0800 Subject: [PATCH 23/57] feat: optimize api response error --- .../extension/api_based_extension_requestor.py | 2 +- api/core/external_data_tool/api/api.py | 16 +++++++++++----- ...aeee_add_external_data_tools_in_app_model_.py | 2 -- api/services/completion_service.py | 3 ++- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 6e5996ab73649d..d35acd95bea317 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -54,6 +54,6 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format(response.status_code, response.content)) + raise ValueError("request error, status_code: {}".format(response.status_code)) return response.json() diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index cdcf4c2ae9947f..32adee77fc1213 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -64,11 +64,17 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: token=api_based_extension.api_key ) - # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + try: + # request api + requestor = APIBasedExtensionRequestor( + api_endpoint=api_based_extension.api_endpoint, + api_key=api_key + ) + except Exception as e: + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( + self.config.get('variable'), + e + )) response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ 'app_id': self.app_id, diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 9c8dcd4768f0fc..9b452f75eed6d4 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -18,8 +18,6 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('sessions') - op.drop_table('third_party_provider_applies') with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 54b150d155fb82..df12a88718462a 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -201,7 +201,7 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m ) except ConversationTaskStoppedException: pass - except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError) as e: PubHandler.pub_error(user, generate_task_id, e) @@ -508,6 +508,7 @@ def handle_error(cls, result: dict): # handle errors llm_errors = { + 'ValueError': LLMBadRequestError, 'LLMBadRequestError': LLMBadRequestError, 'LLMAPIConnectionError': LLMAPIConnectionError, 'LLMAPIUnavailableError': LLMAPIUnavailableError, From c9adf399773bb09e23043dd2d80b118092ba165d Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 2 Nov 2023 10:54:12 +0800 Subject: [PATCH 24/57] feat: optimize api response error --- api/core/external_data_tool/api/api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 32adee77fc1213..8896a00699e7cf 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -56,7 +56,9 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: ).first() if not api_based_extension: - raise ValueError("api_based_extension_id is invalid") + raise ValueError("[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid" + .format(self.config.get('variable'))) # decrypt api_key api_key = encrypter.decrypt_token( @@ -84,6 +86,7 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: }) if 'result' not in response_json: - raise ValueError("result not found in response") + raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" + .format(self.config.get('variable'))) return response_json['result'] From 93b928c8702af57ecd39aea5293a7a6dc69dff9b Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 2 Nov 2023 13:17:02 +0800 Subject: [PATCH 25/57] add moderation api. --- api/controllers/console/__init__.py | 2 +- api/controllers/console/moderation.py | 24 +++++++++++++++++++++++ api/core/moderation/base.py | 13 +++++++++++- api/core/moderation/factory.py | 4 ++-- api/core/moderation/keywords/keywords.py | 17 ++++++++++------ api/core/moderation/openai/openai.py | 15 +++++++++----- api/models/model.py | 2 +- api/services/moderation_service.py | 25 ++++++++++++++++++++++++ 8 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 api/controllers/console/moderation.py create mode 100644 api/services/moderation_service.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ac881dc126c0d0..503563ca957f12 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -6,7 +6,7 @@ api = ExternalApi(bp) # Import other controllers -from . import extension, setup, version, apikey, admin +from . import extension, setup, version, apikey, admin, moderation # Import app controllers from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio diff --git a/api/controllers/console/moderation.py b/api/controllers/console/moderation.py new file mode 100644 index 00000000000000..401b2afae9d78f --- /dev/null +++ b/api/controllers/console/moderation.py @@ -0,0 +1,24 @@ +from flask_restful import Resource, reqparse +from flask_login import current_user + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from services.moderation_service import ModerationService + +class ModerationAPI(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, location='json') + parser.add_argument('text', type=str, required=True, location='json') + args = parser.parse_args() + + service = ModerationService() + return service.moderation_for_outputs(args['app_id'], current_user.current_tenant_id, args['text']) + +api.add_resource(ModerationAPI, '/moderation') \ No newline at end of file diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index d6508ba98337b4..54e73585c865ea 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,8 +1,19 @@ from abc import ABC, abstractmethod from typing import Optional +from pydantic import BaseModel +from enum import Enum from core.extension.extensible import Extensible, ExtensionModule +class ModerationOuputsAction(Enum): + DIRECT_OUTPUT = "direct_output" + OVERRIDED = "overrided" + +class ModerationOutputsResult(BaseModel): + flagged: bool = False + action: ModerationOuputsAction + preset_response: str = "" + text: str = "" class Moderation(Extensible, ABC): """ @@ -39,7 +50,7 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): raise NotImplementedError @abstractmethod - def moderation_for_outputs(self, text: str): + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: """ Moderation for outputs. When LLM outputs content, the front end will pass the output content (may be segmented) diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 559c72662941d8..18a72e1e7c899a 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,7 +1,7 @@ from typing import Optional from core.extension.extensible import ExtensionModule -from core.moderation.base import Moderation +from core.moderation.base import Moderation, ModerationOutputsResult from extensions.ext_code_based_extension import code_based_extension @@ -38,7 +38,7 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): """ return self.__extension_instance.moderation_for_inputs(inputs, query) - def moderation_for_outputs(self, text: str): + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: """ Moderation for outputs. When LLM outputs content, the front end will pass the output content (may be segmented) diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 6320d803a69747..90f4f52185b3f3 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ from typing import Optional -from core.moderation.base import Moderation, ModerationException +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOuputsAction class KeywordsModeration(Moderation): @@ -30,21 +30,26 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): keywords_list = self.config['keywords'].split('\n') preset_response = self.config['inputs_configs']['preset_response'] - self._is_violated(inputs, preset_response, keywords_list) + if self._is_violated(inputs, preset_response, keywords_list): + raise ModerationException(preset_response) - def moderation_for_outputs(self, text: str): + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: if not self.config['outputs_configs']['enabled']: return keywords_list = self.config['keywords'].split('\n') preset_response = self.config['outputs_configs']['preset_response'] - self._is_violated({'text': text}, preset_response, keywords_list) + flagged = self._is_violated({'text': text}, preset_response, keywords_list) - def _is_violated(self, inputs: dict, preset_response: str, keywords_list: list): + return ModerationOutputsResult(flagged=flagged, action=ModerationOuputsAction.DIRECT_OUTPUT, preset_response=preset_response) + + def _is_violated(self, inputs: dict, preset_response: str, keywords_list: list) -> bool: for value in inputs.values(): if self._check_keywords_in_text(keywords_list, value): - raise ModerationException(preset_response) + return True + + return False def _check_keywords_in_text(self, keywords_list, text): for keyword in keywords_list: diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai/openai.py index 857ef5b9c35b92..bdf17fefc8670f 100644 --- a/api/core/moderation/openai/openai.py +++ b/api/core/moderation/openai/openai.py @@ -3,7 +3,7 @@ from typing import Optional from core.helper.encrypter import decrypt_token -from core.moderation.base import Moderation, ModerationException +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOuputsAction from extensions.ext_database import db from models.provider import Provider @@ -30,15 +30,18 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): if query: inputs['query__'] = query - self._is_violated(inputs, preset_response) + if self._is_violated(inputs, preset_response): + raise ModerationException(preset_response) - def moderation_for_outputs(self, text: str): + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: if not self.config['outputs_configs']['enabled']: return preset_response = self.config['inputs_configs']['preset_response'] - self._is_violated({ 'text': text }, preset_response) + flagged = self._is_violated({ 'text': text }, preset_response) + + return ModerationOutputsResult(flagged=flagged, action=ModerationOuputsAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict, preset_response: str): @@ -47,7 +50,9 @@ def _is_violated(self, inputs: dict, preset_response: str): for result in moderation_result.results: if result['flagged']: - raise ModerationException(preset_response) + return True + + return False def _get_openai_api_key(self) -> str: provider = db.session.query(Provider) \ diff --git a/api/models/model.py b/api/models/model.py index d9c528ef4481ae..cefc275e35ccb9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -134,7 +134,7 @@ def more_like_this_dict(self) -> dict: @property def sensitive_word_avoidance_dict(self) -> dict: return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ - else {"enabled": False, "words": [], "canned_response": []} + else {"enabled": False, "type": "", "configs": []} @property def external_data_tools_list(self) -> list[dict]: diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py new file mode 100644 index 00000000000000..e8ef67cc994a58 --- /dev/null +++ b/api/services/moderation_service.py @@ -0,0 +1,25 @@ +import json + +from core.moderation.factory import ModerationFactory +from extensions.ext_database import db +from models.model import AppModelConfig, App + +class ModerationService: + + def moderation_for_outputs(self, app_id: str, tenant_id: str, text: str) -> dict: + + app_model_config = db.session.query(AppModelConfig) \ + .join(App, App.app_model_config_id == AppModelConfig.id) \ + .filter(App.id == app_id) \ + .filter(App.tenant_id == tenant_id) \ + .first() + + if not app_model_config: + raise ValueError("app model config not found") + + name = app_model_config.sensitive_word_avoidance_dict['type'] + config = app_model_config.sensitive_word_avoidance_dict['configs'] + + moderation = ModerationFactory(name, tenant_id, config) + json_str = moderation.moderation_for_outputs(text).json() + return json.loads(json_str) \ No newline at end of file From 513eb21ef8569c95dafd5b22d1ee53ca1f86b8a5 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 2 Nov 2023 14:53:33 +0800 Subject: [PATCH 26/57] update. --- api/core/moderation/{openai => openai_moderation}/__builtin__ | 0 api/core/moderation/{openai => openai_moderation}/__init__.py | 0 .../{openai/openai.py => openai_moderation/openai_moderation.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename api/core/moderation/{openai => openai_moderation}/__builtin__ (100%) rename api/core/moderation/{openai => openai_moderation}/__init__.py (100%) rename api/core/moderation/{openai/openai.py => openai_moderation/openai_moderation.py} (100%) diff --git a/api/core/moderation/openai/__builtin__ b/api/core/moderation/openai_moderation/__builtin__ similarity index 100% rename from api/core/moderation/openai/__builtin__ rename to api/core/moderation/openai_moderation/__builtin__ diff --git a/api/core/moderation/openai/__init__.py b/api/core/moderation/openai_moderation/__init__.py similarity index 100% rename from api/core/moderation/openai/__init__.py rename to api/core/moderation/openai_moderation/__init__.py diff --git a/api/core/moderation/openai/openai.py b/api/core/moderation/openai_moderation/openai_moderation.py similarity index 100% rename from api/core/moderation/openai/openai.py rename to api/core/moderation/openai_moderation/openai_moderation.py From 0a2cce6528ae9dbc87ca526931b167c52fa70b21 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 2 Nov 2023 16:17:19 +0800 Subject: [PATCH 27/57] refactor. --- api/controllers/console/moderation.py | 2 +- api/core/moderation/base.py | 5 ---- api/core/moderation/keywords/keywords.py | 27 ++++++++++--------- .../openai_moderation/openai_moderation.py | 19 ++++++------- api/services/moderation_service.py | 18 ++++++------- 5 files changed, 34 insertions(+), 37 deletions(-) diff --git a/api/controllers/console/moderation.py b/api/controllers/console/moderation.py index 401b2afae9d78f..b3c4374ae9cc0e 100644 --- a/api/controllers/console/moderation.py +++ b/api/controllers/console/moderation.py @@ -19,6 +19,6 @@ def post(self): args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(args['app_id'], current_user.current_tenant_id, args['text']) + return service.moderation_for_outputs(args['app_id'], args['text']) api.add_resource(ModerationAPI, '/moderation') \ No newline at end of file diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 54e73585c865ea..ccb8a7caf58572 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -5,14 +5,9 @@ from core.extension.extensible import Extensible, ExtensionModule -class ModerationOuputsAction(Enum): - DIRECT_OUTPUT = "direct_output" - OVERRIDED = "overrided" class ModerationOutputsResult(BaseModel): flagged: bool = False - action: ModerationOuputsAction - preset_response: str = "" text: str = "" class Moderation(Extensible, ABC): diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 90f4f52185b3f3..b9a3e128b02421 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ from typing import Optional -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOuputsAction +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult class KeywordsModeration(Moderation): @@ -30,29 +30,30 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): keywords_list = self.config['keywords'].split('\n') preset_response = self.config['inputs_configs']['preset_response'] - if self._is_violated(inputs, preset_response, keywords_list): + if self._is_violated(inputs, keywords_list): raise ModerationException(preset_response) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: - if not self.config['outputs_configs']['enabled']: - return - - keywords_list = self.config['keywords'].split('\n') - preset_response = self.config['outputs_configs']['preset_response'] + flagged = False + preset_response = "" - flagged = self._is_violated({'text': text}, preset_response, keywords_list) + if self.config['outputs_configs']['enabled']: + keywords_list = self.config['keywords'].split('\n') + flagged = self._is_violated({'text': text}, keywords_list) + if flagged: + preset_response = self.config['outputs_configs']['preset_response'] - return ModerationOutputsResult(flagged=flagged, action=ModerationOuputsAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult(flagged=flagged, text=preset_response) - def _is_violated(self, inputs: dict, preset_response: str, keywords_list: list) -> bool: + def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): - if self._check_keywords_in_text(keywords_list, value): + if self._check_keywords_in_value(keywords_list, value): return True return False - def _check_keywords_in_text(self, keywords_list, text): + def _check_keywords_in_value(self, keywords_list, value): for keyword in keywords_list: - if keyword in text: + if keyword.lower() in value.lower(): return True return False diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index bdf17fefc8670f..01541341d15bc8 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -3,7 +3,7 @@ from typing import Optional from core.helper.encrypter import decrypt_token -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOuputsAction +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult from extensions.ext_database import db from models.provider import Provider @@ -30,20 +30,21 @@ def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): if query: inputs['query__'] = query - if self._is_violated(inputs, preset_response): + if self._is_violated(inputs): raise ModerationException(preset_response) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: - if not self.config['outputs_configs']['enabled']: - return - - preset_response = self.config['inputs_configs']['preset_response'] + flagged = False + preset_response = "" - flagged = self._is_violated({ 'text': text }, preset_response) + if self.config['outputs_configs']['enabled']: + flagged = self._is_violated({ 'text': text }) + if flagged: + preset_response = self.config['outputs_configs']['preset_response'] - return ModerationOutputsResult(flagged=flagged, action=ModerationOuputsAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult(flagged=flagged, text=preset_response) - def _is_violated(self, inputs: dict, preset_response: str): + def _is_violated(self, inputs: dict): openai_api_key = self._get_openai_api_key() moderation_result = openai.Moderation.create(input=list(inputs.values()), api_key=openai_api_key) diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index e8ef67cc994a58..ec23f4dc7b1d2c 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -6,13 +6,14 @@ class ModerationService: - def moderation_for_outputs(self, app_id: str, tenant_id: str, text: str) -> dict: + def moderation_for_outputs(self, app_id: str, text: str) -> dict: - app_model_config = db.session.query(AppModelConfig) \ - .join(App, App.app_model_config_id == AppModelConfig.id) \ - .filter(App.id == app_id) \ - .filter(App.tenant_id == tenant_id) \ - .first() + app = db.session.query(App).filter(App.id == app_id).first() + + if not app: + raise ValueError("app not found") + + app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app.app_model_config_id).first() if not app_model_config: raise ValueError("app model config not found") @@ -20,6 +21,5 @@ def moderation_for_outputs(self, app_id: str, tenant_id: str, text: str) -> dict name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['configs'] - moderation = ModerationFactory(name, tenant_id, config) - json_str = moderation.moderation_for_outputs(text).json() - return json.loads(json_str) \ No newline at end of file + moderation = ModerationFactory(name, app.tenant_id, config) + return moderation.moderation_for_outputs(text).dict() \ No newline at end of file From 26c37adb1e7d39854829f09e35b75faf2f73da74 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 2 Nov 2023 18:08:12 +0800 Subject: [PATCH 28/57] update. --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/completion.py | 15 ++++++++++++ api/controllers/console/explore/completion.py | 13 +++++++++- api/controllers/console/moderation.py | 24 ------------------- api/controllers/service_api/app/completion.py | 12 +++++++++- api/controllers/web/completion.py | 12 ++++++++++ api/core/moderation/base.py | 1 - api/services/moderation_service.py | 16 +++++-------- 8 files changed, 57 insertions(+), 38 deletions(-) delete mode 100644 api/controllers/console/moderation.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 503563ca957f12..ac881dc126c0d0 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -6,7 +6,7 @@ api = ExternalApi(bp) # Import other controllers -from . import extension, setup, version, apikey, admin, moderation +from . import extension, setup, version, apikey, admin # Import app controllers from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 1da7bd8f2ce974..e6ad2f46dfd514 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -23,6 +23,7 @@ from flask_restful import Resource, reqparse from services.completion_service import CompletionService +from services.moderation_service import ModerationService # define completion message api for user @@ -207,8 +208,22 @@ def post(self, app_id, task_id): return {'result': 'success'}, 200 +class ModerationApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self, app_id): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, location='json') + parser.add_argument('text', type=str, required=True, location='json') + args = parser.parse_args() + + service = ModerationService() + return service.moderation_for_outputs(_get_app(str(app_id), None), flask_login.current_user, args['text']) api.add_resource(CompletionMessageApi, '/apps//completion-messages') api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') api.add_resource(ChatMessageApi, '/apps//chat-messages') api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') +api.add_resource(ModerationApi, '/apps/moderation') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index bdf1f3b907dfdd..6e8990a81e82f3 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -19,7 +19,7 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService - +from services.moderation_service import ModerationService # define completion api for user class CompletionApi(InstalledAppResource): @@ -175,8 +175,19 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') +class ModerationApi(InstalledAppResource): + + def post(self, installed_app): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, location='json') + parser.add_argument('text', type=str, required=True, location='json') + args = parser.parse_args() + + service = ModerationService() + return service.moderation_for_outputs(installed_app.app, current_user, args['text']) api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') +api.add_resource(ModerationApi, '/installed-apps/moderation') \ No newline at end of file diff --git a/api/controllers/console/moderation.py b/api/controllers/console/moderation.py deleted file mode 100644 index b3c4374ae9cc0e..00000000000000 --- a/api/controllers/console/moderation.py +++ /dev/null @@ -1,24 +0,0 @@ -from flask_restful import Resource, reqparse -from flask_login import current_user - -from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from libs.login import login_required -from services.moderation_service import ModerationService - -class ModerationAPI(Resource): - - @setup_required - @login_required - @account_initialization_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, location='json') - parser.add_argument('text', type=str, required=True, location='json') - args = parser.parse_args() - - service = ModerationService() - return service.moderation_for_outputs(args['app_id'], args['text']) - -api.add_resource(ModerationAPI, '/moderation') \ No newline at end of file diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index a339322ea89a8e..3729471a93b059 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -18,6 +18,7 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService +from services.moderation_service import ModerationService class CompletionApi(AppApiResource): @@ -178,9 +179,18 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') +class ModerationAPI(AppApiResource): + def post(self, app_model, end_user): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, location='json') + parser.add_argument('text', type=str, required=True, location='json') + args = parser.parse_args() + + service = ModerationService() + return service.moderation_for_outputs(app_model, end_user, args['text']) api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') api.add_resource(ChatStopApi, '/chat-messages//stop') - +api.add_resource(ModerationAPI, '/moderation') \ No newline at end of file diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 79c0c542d17852..efe02ddad0368e 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -18,6 +18,7 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService +from services.moderation_service import ModerationService # define completion api for user @@ -172,8 +173,19 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') +class ModerationAPI(WebApiResource): + + def post(self, app_model, end_user): + parser = reqparse.RequestParser() + parser.add_argument('app_id', type=str, required=True, location='json') + parser.add_argument('text', type=str, required=True, location='json') + args = parser.parse_args() + + service = ModerationService() + return service.moderation_for_outputs(app_model, end_user, args['text']) api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(ModerationAPI, '/moderation') diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index ccb8a7caf58572..3e4de36ca95ed3 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional from pydantic import BaseModel -from enum import Enum from core.extension.extensible import Extensible, ExtensionModule diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index ec23f4dc7b1d2c..0ab0b42e7a1ae0 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,25 +1,21 @@ -import json +from typing import Union +from models.model import AppModelConfig, App, Account, EndUser from core.moderation.factory import ModerationFactory from extensions.ext_database import db -from models.model import AppModelConfig, App class ModerationService: - def moderation_for_outputs(self, app_id: str, text: str) -> dict: + def moderation_for_outputs(self, app_model: App, user: Union[Account , EndUser], text: str) -> dict: + app_model_config: AppModelConfig = None - app = db.session.query(App).filter(App.id == app_id).first() + app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() - if not app: - raise ValueError("app not found") - - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app.app_model_config_id).first() - if not app_model_config: raise ValueError("app model config not found") name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['configs'] - moderation = ModerationFactory(name, app.tenant_id, config) + moderation = ModerationFactory(name, user.tenant_id, config) return moderation.moderation_for_outputs(text).dict() \ No newline at end of file From 8fdcee45461dab4cc4f5e6508503e356338edf23 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 10:13:52 +0800 Subject: [PATCH 29/57] update. --- api/controllers/console/explore/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 6e8990a81e82f3..e917f8f9b0f004 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -190,4 +190,4 @@ def post(self, installed_app): api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') -api.add_resource(ModerationApi, '/installed-apps/moderation') \ No newline at end of file +api.add_resource(ModerationApi, '/installed-apps/moderation', endpoint='installed_app_moderation') \ No newline at end of file From a86fc23f7698f4bf6b25be05eca95da4a095e4b6 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 10:25:52 +0800 Subject: [PATCH 30/57] update. --- api/controllers/console/app/completion.py | 4 ++-- api/services/moderation_service.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index e6ad2f46dfd514..ca6cdd6cd60861 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -213,14 +213,14 @@ class ModerationApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + def post(self): parser = reqparse.RequestParser() parser.add_argument('app_id', type=str, required=True, location='json') parser.add_argument('text', type=str, required=True, location='json') args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(_get_app(str(app_id), None), flask_login.current_user, args['text']) + return service.moderation_for_outputs(_get_app(str(args['app_id']), None), flask_login.current_user, args['text']) api.add_resource(CompletionMessageApi, '/apps//completion-messages') api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index 0ab0b42e7a1ae0..dab7b3f4a456e5 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -17,5 +17,5 @@ def moderation_for_outputs(self, app_model: App, user: Union[Account , EndUser], name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['configs'] - moderation = ModerationFactory(name, user.tenant_id, config) + moderation = ModerationFactory(name, user.current_tenant_id, config) return moderation.moderation_for_outputs(text).dict() \ No newline at end of file From 98fa9d5936092f31fba4168310f6d7dd68e97ada Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 11:56:23 +0800 Subject: [PATCH 31/57] refactor. --- api/controllers/console/app/completion.py | 2 +- api/controllers/console/explore/completion.py | 4 ++-- api/controllers/service_api/app/completion.py | 3 ++- api/controllers/web/completion.py | 2 +- api/services/moderation_service.py | 8 +++----- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index ca6cdd6cd60861..ef41d8e5beae8b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -220,7 +220,7 @@ def post(self): args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(_get_app(str(args['app_id']), None), flask_login.current_user, args['text']) + return service.moderation_for_outputs(_get_app(str(args['app_id']), None), args['text']) api.add_resource(CompletionMessageApi, '/apps//completion-messages') api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index e917f8f9b0f004..991f991de96bd8 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -184,10 +184,10 @@ def post(self, installed_app): args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(installed_app.app, current_user, args['text']) + return service.moderation_for_outputs(installed_app.app, args['text']) api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') -api.add_resource(ModerationApi, '/installed-apps/moderation', endpoint='installed_app_moderation') \ No newline at end of file +api.add_resource(ModerationApi, '/installed-apps//moderation', endpoint='installed_app_moderation') \ No newline at end of file diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3729471a93b059..9be8a10e7b9c1f 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -180,6 +180,7 @@ def generate() -> Generator: mimetype='text/event-stream') class ModerationAPI(AppApiResource): + def post(self, app_model, end_user): parser = reqparse.RequestParser() parser.add_argument('app_id', type=str, required=True, location='json') @@ -187,7 +188,7 @@ def post(self, app_model, end_user): args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(app_model, end_user, args['text']) + return service.moderation_for_outputs(app_model, args['text']) api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index efe02ddad0368e..0410e697337664 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -182,7 +182,7 @@ def post(self, app_model, end_user): args = parser.parse_args() service = ModerationService() - return service.moderation_for_outputs(app_model, end_user, args['text']) + return service.moderation_for_outputs(app_model, args['text']) api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index dab7b3f4a456e5..6e506e5d735f58 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,12 +1,10 @@ -from typing import Union - -from models.model import AppModelConfig, App, Account, EndUser +from models.model import AppModelConfig, App from core.moderation.factory import ModerationFactory from extensions.ext_database import db class ModerationService: - def moderation_for_outputs(self, app_model: App, user: Union[Account , EndUser], text: str) -> dict: + def moderation_for_outputs(self, app_model: App, text: str) -> dict: app_model_config: AppModelConfig = None app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() @@ -17,5 +15,5 @@ def moderation_for_outputs(self, app_model: App, user: Union[Account , EndUser], name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['configs'] - moderation = ModerationFactory(name, user.current_tenant_id, config) + moderation = ModerationFactory(name, app_model.tenant_id, config) return moderation.moderation_for_outputs(text).dict() \ No newline at end of file From e47dbca6aa586f6aa8b8ceaaecabe7d7ce686b30 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 13:58:14 +0800 Subject: [PATCH 32/57] update. --- api/controllers/console/explore/parameter.py | 4 +++- api/controllers/service_api/app/app.py | 4 +++- api/controllers/web/app.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fb4ce33209ec71..9514834a2be937 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -27,6 +27,7 @@ class AppParameterApi(InstalledAppResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -42,7 +43,8 @@ def get(self, installed_app: InstalledApp): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 86b8642571108b..cca4c1addba504 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -28,6 +28,7 @@ class AppParameterApi(AppApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -42,7 +43,8 @@ def get(self, app_model: App, end_user): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cffda04eea66d5..bb99e26ad165ac 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -27,6 +27,7 @@ class AppParameterApi(WebApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -41,7 +42,8 @@ def get(self, app_model: App, end_user): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } From 348eebbb7c09eee3d84ddb1d41fde8e4b2ae192e Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 15:08:06 +0800 Subject: [PATCH 33/57] update. --- api/fields/api_based_extension_fields.py | 2 +- api/models/api_based_extension.py | 1 + api/services/api_based_extension_service.py | 39 ++++++++++++++++----- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index e9319029d3175f..6dc0a0b823bc3d 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -5,7 +5,7 @@ class HiddenAPIKey(fields.Raw): def output(self, key, obj): - return obj.api_key[:8] + '***' + obj.api_key[-8:] + return obj.api_key[:3] + '***' + obj.api_key[-3:] api_based_extension_fields = { 'id': fields.String, diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 9468f0897ce148..dc19f6d751d971 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -7,6 +7,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' + PING = 'ping' class APIBasedExtension(db.Model): diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 4c42212f64ed94..d1809c17ab6da4 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,18 +1,24 @@ from extensions.ext_database import db -from models.api_based_extension import APIBasedExtension -from core.helper.encrypter import encrypt_token +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from core.helper.encrypter import encrypt_token, decrypt_token +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor class APIBasedExtensionService: @staticmethod def get_all_by_tenant_id(tenant_id: str): - return db.session.query(APIBasedExtension) \ + extension_list = db.session.query(APIBasedExtension) \ .filter_by(tenant_id=tenant_id) \ .order_by(APIBasedExtension.created_at.desc()) \ .all() + + for extension in extension_list: + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) - @staticmethod - def save(extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExtension: + return extension_list + + @classmethod + def save(cls, extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExtension: # name if not extension_data.name: raise ValueError("name must not be empty") @@ -45,6 +51,11 @@ def save(extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExten if not extension_data.api_key: raise ValueError("api_key must not be empty") + if len(extension_data.api_key) < 5: + raise ValueError("api_key must be at least 5 characters") + + cls._ping_connection(extension_data) + if need_encrypt: extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) @@ -59,12 +70,24 @@ def delete(extension_data: APIBasedExtension) -> None: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - api_based_extension = db.session.query(APIBasedExtension) \ + extension = db.session.query(APIBasedExtension) \ .filter_by(tenant_id=tenant_id) \ .filter_by(id=api_based_extension_id) \ .first() - if not api_based_extension: + if not extension: raise ValueError("API based extension is not found") - return api_based_extension \ No newline at end of file + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) + + return extension + + @staticmethod + def _ping_connection(extension_data: APIBasedExtension) -> None: + try: + client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) + resp = client.request(point=APIBasedExtensionPoint.PING, params={}) + if resp.get('result') != 'pong': + raise ValueError(resp) + except Exception as e: + raise ValueError("connection error: {}".format(e)) \ No newline at end of file From 89b7cb21c3bcea9ab56c40ef009fa6fda618a4e6 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 15:33:59 +0800 Subject: [PATCH 34/57] feat: optimize APIBasedExtensionRequestor error message --- api/core/extension/api_based_extension_requestor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index d35acd95bea317..8ce7edabf23c47 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -54,6 +54,9 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}".format(response.status_code)) + raise ValueError("request error, status_code: {}, content: {}".format( + response.status_code, + response.text[:100] + )) return response.json() From e1f44424e904b9dcb6acdc460f07cae7337200e6 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 15:56:04 +0800 Subject: [PATCH 35/57] update. --- api/controllers/console/extension.py | 1 + api/services/api_based_extension_service.py | 59 +++++++++++---------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 30d38f6bf466f3..c2feff6d716da9 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -91,6 +91,7 @@ def post(self, id): need_encrypt = False if args['api_key'] != '[__HIDDEN__]': need_encrypt = True + extension_data_from_db.api_key = args['api_key'] return APIBasedExtensionService.save(extension_data_from_db, need_encrypt) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index d1809c17ab6da4..53e08de47a676f 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -6,7 +6,7 @@ class APIBasedExtensionService: @staticmethod - def get_all_by_tenant_id(tenant_id: str): + def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: extension_list = db.session.query(APIBasedExtension) \ .filter_by(tenant_id=tenant_id) \ .order_by(APIBasedExtension.created_at.desc()) \ @@ -19,6 +19,36 @@ def get_all_by_tenant_id(tenant_id: str): @classmethod def save(cls, extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExtension: + cls._validation(extension_data) + + if need_encrypt: + extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) + + db.session.add(extension_data) + db.session.commit() + return extension_data + + @staticmethod + def delete(extension_data: APIBasedExtension) -> None: + db.session.delete(extension_data) + db.session.commit() + + @staticmethod + def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + extension = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=tenant_id) \ + .filter_by(id=api_based_extension_id) \ + .first() + + if not extension: + raise ValueError("API based extension is not found") + + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) + + return extension + + @classmethod + def _validation(cls, extension_data: APIBasedExtension) -> None: # name if not extension_data.name: raise ValueError("name must not be empty") @@ -54,33 +84,8 @@ def save(cls, extension_data: APIBasedExtension, need_encrypt: bool) -> APIBased if len(extension_data.api_key) < 5: raise ValueError("api_key must be at least 5 characters") + # check endpoint cls._ping_connection(extension_data) - - if need_encrypt: - extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) - - db.session.add(extension_data) - db.session.commit() - return extension_data - - @staticmethod - def delete(extension_data: APIBasedExtension) -> None: - db.session.delete(extension_data) - db.session.commit() - - @staticmethod - def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ - .first() - - if not extension: - raise ValueError("API based extension is not found") - - extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) - - return extension @staticmethod def _ping_connection(extension_data: APIBasedExtension) -> None: From 82120da2afed021b6a7e848d06f3538279f4d05d Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 16:35:35 +0800 Subject: [PATCH 36/57] remove sample code. --- api/core/moderation/cloud_service/__init__.py | 0 .../moderation/cloud_service/cloud_service.py | 13 ----- api/core/moderation/cloud_service/schema.json | 47 ------------------- 3 files changed, 60 deletions(-) delete mode 100644 api/core/moderation/cloud_service/__init__.py delete mode 100644 api/core/moderation/cloud_service/cloud_service.py delete mode 100644 api/core/moderation/cloud_service/schema.json diff --git a/api/core/moderation/cloud_service/__init__.py b/api/core/moderation/cloud_service/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/api/core/moderation/cloud_service/cloud_service.py b/api/core/moderation/cloud_service/cloud_service.py deleted file mode 100644 index ced2b118b333fa..00000000000000 --- a/api/core/moderation/cloud_service/cloud_service.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Optional - -from core.moderation.base import Moderation - - -class CloudServiceModeration(Moderation): - name: str = "cloud_service" - - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - pass - - def moderation_for_outputs(self, text: str): - pass \ No newline at end of file diff --git a/api/core/moderation/cloud_service/schema.json b/api/core/moderation/cloud_service/schema.json deleted file mode 100644 index 3e05822ff60cd2..00000000000000 --- a/api/core/moderation/cloud_service/schema.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "label": { - "en-US": "Cloud Service", - "zh-Hans": "云服务" - }, - "form_schema": [ - { - "type": "select", - "label": { - "en-US": "Cloud Provider", - "zh-Hans": "云计算厂商" - }, - "variable": "cloud_provider", - "required": true, - "options": [ - "腾讯云", - "阿里云", - "AWS" - ], - "default": "", - "placeholder": "" - }, - { - "type": "text-input", - "label": { - "en-US": "API Endpoint", - "zh-Hans": "API Endpoint" - }, - "variable": "api_endpoint", - "required": true, - "max_length": 100, - "default": "", - "placeholder": "" - }, - { - "type": "paragraph", - "label": { - "en-US": "API Key", - "zh-Hans": "API Key" - }, - "variable": "api_keys", - "required": true, - "default": "", - "placeholder": "" - } - ] -} \ No newline at end of file From 84af6454444d183f6992a9a5b81d3cb42ee7dbc0 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 17:41:10 +0800 Subject: [PATCH 37/57] remove s. --- api/core/completion.py | 2 +- api/core/moderation/base.py | 32 +++++++++---------- api/core/moderation/keywords/keywords.py | 8 ++--- .../openai_moderation/openai_moderation.py | 8 ++--- api/services/app_model_config_service.py | 2 +- api/services/moderation_service.py | 2 +- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/api/core/completion.py b/api/core/completion.py index 909a955c0e1881..2e85af8dff8a03 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -165,7 +165,7 @@ def moderation_for_inputs(cls, tenant_id: str, app_model_config: AppModelConfig, type = app_model_config.sensitive_word_avoidance_dict['type'] - moderation = ModerationFactory(type, tenant_id, app_model_config.sensitive_word_avoidance_dict['configs']) + moderation = ModerationFactory(type, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) moderation.moderation_for_inputs(inputs, query) @classmethod diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 3e4de36ca95ed3..1884ccecdae9ca 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -57,30 +57,30 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: @classmethod def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: - # inputs_configs - inputs_configs = config.get("inputs_configs") - if not isinstance(inputs_configs, dict): - raise ValueError("inputs_configs must be a dict") + # inputs_config + inputs_config = config.get("inputs_config") + if not isinstance(inputs_config, dict): + raise ValueError("inputs_config must be a dict") - # outputs_configs - outputs_configs = config.get("outputs_configs") - if not isinstance(outputs_configs, dict): - raise ValueError("outputs_configs must be a dict") + # outputs_config + outputs_config = config.get("outputs_config") + if not isinstance(outputs_config, dict): + raise ValueError("outputs_config must be a dict") - inputs_configs_enabled = inputs_configs.get("enabled") - outputs_configs_enabled = outputs_configs.get("enabled") - if not inputs_configs_enabled and not outputs_configs_enabled: - raise ValueError("At least one of inputs_configs or outputs_configs must be enabled") + inputs_config_enabled = inputs_config.get("enabled") + outputs_config_enabled = outputs_config.get("enabled") + if not inputs_config_enabled and not outputs_config_enabled: + raise ValueError("At least one of inputs_config or outputs_config must be enabled") # preset_response if not is_preset_response_required: return - if inputs_configs_enabled and not inputs_configs.get("preset_response"): - raise ValueError("inputs_configs.preset_response is required") + if inputs_config_enabled and not inputs_config.get("preset_response"): + raise ValueError("inputs_config.preset_response is required") - if outputs_configs_enabled and not outputs_configs.get("preset_response"): - raise ValueError("outputs_configs.preset_response is required") + if outputs_config_enabled and not outputs_config.get("preset_response"): + raise ValueError("outputs_config.preset_response is required") class ModerationException(Exception): pass \ No newline at end of file diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index b9a3e128b02421..7326afc979f346 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -21,14 +21,14 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: raise ValueError("keywords is required") def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - if not self.config['inputs_configs']['enabled']: + if not self.config['inputs_config']['enabled']: return if query: inputs['query__'] = query keywords_list = self.config['keywords'].split('\n') - preset_response = self.config['inputs_configs']['preset_response'] + preset_response = self.config['inputs_config']['preset_response'] if self._is_violated(inputs, keywords_list): raise ModerationException(preset_response) @@ -37,11 +37,11 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_configs']['enabled']: + if self.config['outputs_config']['enabled']: keywords_list = self.config['keywords'].split('\n') flagged = self._is_violated({'text': text}, keywords_list) if flagged: - preset_response = self.config['outputs_configs']['preset_response'] + preset_response = self.config['outputs_config']['preset_response'] return ModerationOutputsResult(flagged=flagged, text=preset_response) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 01541341d15bc8..3300fc623e71fd 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -23,10 +23,10 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: cls._validate_inputs_and_outputs_config(config, True) def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - if not self.config['inputs_configs']['enabled']: + if not self.config['inputs_config']['enabled']: return - preset_response = self.config['inputs_configs']['preset_response'] + preset_response = self.config['inputs_config']['preset_response'] if query: inputs['query__'] = query @@ -37,10 +37,10 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_configs']['enabled']: + if self.config['outputs_config']['enabled']: flagged = self._is_violated({ 'text': text }) if flagged: - preset_response = self.config['outputs_configs']['preset_response'] + preset_response = self.config['outputs_config']['preset_response'] return ModerationOutputsResult(flagged=flagged, text=preset_response) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 9f62734b007b5d..64809c8fc604d9 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -368,7 +368,7 @@ def is_moderation_valid(cls, tenant_id: str, config: dict): raise ValueError("sensitive_word_avoidance.type is required") type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["configs"] + config = config["sensitive_word_avoidance"]["config"] ModerationFactory.validate_config( name=type, diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index 6e506e5d735f58..7c1a36384a180d 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -13,7 +13,7 @@ def moderation_for_outputs(self, app_model: App, text: str) -> dict: raise ValueError("app model config not found") name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['configs'] + config = app_model_config.sensitive_word_avoidance_dict['config'] moderation = ModerationFactory(name, app_model.tenant_id, config) return moderation.moderation_for_outputs(text).dict() \ No newline at end of file From f67c2dadc04b1b2c82ffe84a6d9c97101f07676c Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 17:19:51 +0800 Subject: [PATCH 38/57] feat: add moderation output logic --- .../callback_handler/llm_callback_handler.py | 88 +++++++++++++++++-- api/core/moderation/base.py | 9 ++ 2 files changed, 89 insertions(+), 8 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index b8eb99b2e5bfc3..cd8a0bda1a1eaf 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -3,11 +3,19 @@ from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult, BaseMessage +from pydantic import BaseModel from core.callback_handler.entity.llm_message import LLMMessage from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage from core.model_providers.models.llm.base import BaseLLM +from core.moderation.base import ModerationOutputsResult, ModerationOutputsAction +from core.moderation.factory import ModerationFactory + + +class ModerationRule(BaseModel): + type: str + config: Dict[str, Any] class LLMCallbackHandler(BaseCallbackHandler): @@ -20,6 +28,19 @@ def __init__(self, model_instance: BaseLLM, self.start_at = None self.conversation_message_task = conversation_message_task + app_model_config = self.conversation_message_task.app_model_config + sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict + + self.is_interrupt = False + self.moderation_rule = None + self.moderation_buffer = '' + self.moderation_chunk = '' + if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): + self.moderation_rule = ModerationRule( + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config") + ) + @property def always_verbose(self) -> bool: """Whether to call verbose callbacks even if verbose is False.""" @@ -60,8 +81,11 @@ def on_llm_start( def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: if not self.conversation_message_task.streaming: - self.conversation_message_task.append_message_text(response.generations[0][0].text) - self.llm_message.completion = response.generations[0][0].text + moderation_result = self.moderation_completion(response.generations[0][0].text) + if not moderation_result: + self.llm_message.completion = response.generations[0][0].text + + self.conversation_message_task.append_message_text(self.llm_message.completion) if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: @@ -79,13 +103,16 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - try: - self.conversation_message_task.append_message_text(token) - except ConversationTaskStoppedException as ex: - self.on_llm_error(error=ex) - raise ex + self.moderation_completion(token) + + if not self.is_interrupt: + try: + self.conversation_message_task.append_message_text(token) + except ConversationTaskStoppedException as ex: + self.on_llm_error(error=ex) + raise ex - self.llm_message.completion += token + self.llm_message.completion += token def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -99,3 +126,48 @@ def on_llm_error( self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) else: logging.debug("on_llm_error: %s", error) + + def moderation_completion(self, token: str) -> bool: + """ + Moderation for outputs. + + :param token: LLM output content + :return: bool + """ + if not self.moderation_rule: + return False + + if len(self.moderation_chunk) < 50: + self.moderation_chunk += token + return False + + moderation_chunk = self.moderation_chunk + self.moderation_chunk = '' + + try: + moderation_factory = ModerationFactory( + name=self.moderation_rule.type, + tenant_id=self.conversation_message_task.tenant_id, + config=self.moderation_rule.config + ) + + result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_chunk) + if not result.flagged: + return False + + if result.action == ModerationOutputsAction.DIRECT_OUTPUT: + self.is_interrupt = True + self.llm_message.completion = result.text + else: + self.llm_message.completion = self.moderation_buffer + moderation_chunk + self.moderation_chunk + + if self.conversation_message_task.streaming: + # TODO trigger replace event + logging.debug("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) + except Exception as e: + logging.error("Moderation Output error: %s", e) + return False + finally: + self.moderation_buffer += moderation_chunk + + return True diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 1884ccecdae9ca..c4ff7962546445 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,14 +1,23 @@ from abc import ABC, abstractmethod from typing import Optional from pydantic import BaseModel +from enum import Enum from core.extension.extensible import Extensible, ExtensionModule +class ModerationOutputsAction(Enum): + DIRECT_OUTPUT = 'direct_output' + OVERRIDE = 'override' + + class ModerationOutputsResult(BaseModel): flagged: bool = False + action: ModerationOutputsAction + preset_response: str = "" text: str = "" + class Moderation(Extensible, ABC): """ The base class of moderation. From 95e2d84f69136bb6b818015e4047580e959bde98 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 18:01:31 +0800 Subject: [PATCH 39/57] refactor. --- api/core/moderation/base.py | 2 +- api/core/moderation/keywords/keywords.py | 7 +++---- api/core/moderation/openai_moderation/openai_moderation.py | 7 +++---- api/services/moderation_service.py | 5 ++++- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index c4ff7962546445..41c367a03a821c 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,7 +8,7 @@ class ModerationOutputsAction(Enum): DIRECT_OUTPUT = 'direct_output' - OVERRIDE = 'override' + OVERRIDED = 'overrided' class ModerationOutputsResult(BaseModel): diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 7326afc979f346..d6d21606583671 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ from typing import Optional -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOutputsAction class KeywordsModeration(Moderation): @@ -40,10 +40,9 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: if self.config['outputs_config']['enabled']: keywords_list = self.config['keywords'].split('\n') flagged = self._is_violated({'text': text}, keywords_list) - if flagged: - preset_response = self.config['outputs_config']['preset_response'] + preset_response = self.config['outputs_config']['preset_response'] - return ModerationOutputsResult(flagged=flagged, text=preset_response) + return ModerationOutputsResult(flagged=flagged, action=ModerationOutputsAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 3300fc623e71fd..491779427f46f5 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -3,7 +3,7 @@ from typing import Optional from core.helper.encrypter import decrypt_token -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult +from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOutputsAction from extensions.ext_database import db from models.provider import Provider @@ -39,10 +39,9 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: if self.config['outputs_config']['enabled']: flagged = self._is_violated({ 'text': text }) - if flagged: - preset_response = self.config['outputs_config']['preset_response'] + preset_response = self.config['outputs_config']['preset_response'] - return ModerationOutputsResult(flagged=flagged, text=preset_response) + return ModerationOutputsResult(flagged=flagged, action=ModerationOutputsAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict): diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index 7c1a36384a180d..58b979b591ebbf 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,3 +1,5 @@ +import json + from models.model import AppModelConfig, App from core.moderation.factory import ModerationFactory from extensions.ext_database import db @@ -16,4 +18,5 @@ def moderation_for_outputs(self, app_model: App, text: str) -> dict: config = app_model_config.sensitive_word_avoidance_dict['config'] moderation = ModerationFactory(name, app_model.tenant_id, config) - return moderation.moderation_for_outputs(text).dict() \ No newline at end of file + data = moderation.moderation_for_outputs(text).json() + return json.loads(data) \ No newline at end of file From be3d36280643757296b2fc47eeeb3e5e607d3f09 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 18:03:46 +0800 Subject: [PATCH 40/57] feat: add message replace event --- .../callback_handler/llm_callback_handler.py | 7 +++--- api/core/conversation_message_task.py | 22 +++++++++++++++++++ api/services/completion_service.py | 17 ++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index cd8a0bda1a1eaf..316e28d0d0e3d7 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -157,13 +157,14 @@ def moderation_completion(self, token: str) -> bool: if result.action == ModerationOutputsAction.DIRECT_OUTPUT: self.is_interrupt = True - self.llm_message.completion = result.text + self.llm_message.completion = result.preset_response else: - self.llm_message.completion = self.moderation_buffer + moderation_chunk + self.moderation_chunk + self.llm_message.completion = self.moderation_buffer + result.text + self.moderation_chunk if self.conversation_message_task.streaming: - # TODO trigger replace event + # trigger replace event logging.debug("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) + self.conversation_message_task.on_message_replace(self.llm_message.completion) except Exception as e: logging.error("Moderation Output error: %s", e) return False diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 3be6ffaee37bb0..0358d5cd3dec0e 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -290,6 +290,10 @@ def on_dataset_query_finish(self, resource: List): db.session.commit() self.retriever_resource = resource + def on_message_replace(self, text: str): + if text is not None: + self._pub_handler.pub_message_replace(text) + def message_end(self): self._pub_handler.pub_message_end(self.retriever_resource) @@ -342,6 +346,24 @@ def pub_text(self, text: str): self.pub_end() raise ConversationTaskStoppedException() + def pub_message_replace(self, text: str): + content = { + 'event': 'message_replace', + 'data': { + 'task_id': self._task_id, + 'message_id': str(self._message.id), + 'text': text, + 'mode': self._conversation.mode, + 'conversation_id': str(self._conversation.id) + } + } + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + def pub_chain(self, message_chain: MessageChain): if self._chain_pub: content = { diff --git a/api/services/completion_service.py b/api/services/completion_service.py index df12a88718462a..29575d2dab3e22 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -386,6 +386,8 @@ def generate() -> Generator: break if event == 'message': yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" + elif event == 'message_replace': + yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n" elif event == 'chain': yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" elif event == 'agent_thought': @@ -427,6 +429,21 @@ def get_message_response_data(cls, data: dict): return response_data + @classmethod + def get_message_replace_response_data(cls, data: dict): + response_data = { + 'event': 'message_replace', + 'task_id': data.get('task_id'), + 'id': data.get('message_id'), + 'answer': data.get('text'), + 'created_at': int(time.time()) + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + @classmethod def get_blocking_message_response_data(cls, data: dict): message = data.get('message') From fdaf63c9a0fce4ddcec8ff77f209cedb94f6b16d Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 18:42:06 +0800 Subject: [PATCH 41/57] feat: fix moderation still stream output --- .../callback_handler/llm_callback_handler.py | 43 +++++++++++++------ api/core/completion.py | 5 ++- api/core/conversation_message_task.py | 4 ++ api/services/completion_service.py | 5 ++- 4 files changed, 39 insertions(+), 18 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 316e28d0d0e3d7..aae10c707c2c2e 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -6,7 +6,8 @@ from pydantic import BaseModel from core.callback_handler.entity.llm_message import LLMMessage -from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ + ConversationTaskInterruptException from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage from core.model_providers.models.llm.base import BaseLLM from core.moderation.base import ModerationOutputsResult, ModerationOutputsAction @@ -31,7 +32,6 @@ def __init__(self, model_instance: BaseLLM, app_model_config = self.conversation_message_task.app_model_config sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict - self.is_interrupt = False self.moderation_rule = None self.moderation_buffer = '' self.moderation_chunk = '' @@ -81,11 +81,13 @@ def on_llm_start( def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: if not self.conversation_message_task.streaming: - moderation_result = self.moderation_completion(response.generations[0][0].text) + moderation_result = self.moderation_completion(response.generations[0][0].text, True) if not moderation_result: self.llm_message.completion = response.generations[0][0].text self.conversation_message_task.append_message_text(self.llm_message.completion) + else: + self.moderation_completion(self.llm_message.completion, True) if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: @@ -105,14 +107,13 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None: self.moderation_completion(token) - if not self.is_interrupt: - try: - self.conversation_message_task.append_message_text(token) - except ConversationTaskStoppedException as ex: - self.on_llm_error(error=ex) - raise ex + try: + self.conversation_message_task.append_message_text(token) + except ConversationTaskStoppedException as ex: + self.on_llm_error(error=ex) + raise ex - self.llm_message.completion += token + self.llm_message.completion += token def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -124,10 +125,12 @@ def on_llm_error( [PromptMessage(content=self.llm_message.completion)] ) self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) + if isinstance(error, ConversationTaskInterruptException): + pass else: logging.debug("on_llm_error: %s", error) - def moderation_completion(self, token: str) -> bool: + def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: """ Moderation for outputs. @@ -137,9 +140,12 @@ def moderation_completion(self, token: str) -> bool: if not self.moderation_rule: return False - if len(self.moderation_chunk) < 50: - self.moderation_chunk += token - return False + if not no_chunk: + if len(self.moderation_chunk) < 50: + self.moderation_chunk += token + return False + else: + self.moderation_chunk = token moderation_chunk = self.moderation_chunk self.moderation_chunk = '' @@ -165,6 +171,15 @@ def moderation_completion(self, token: str) -> bool: # trigger replace event logging.debug("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) self.conversation_message_task.on_message_replace(self.llm_message.completion) + + if result.action == ModerationOutputsAction.DIRECT_OUTPUT: + self.llm_message.completion_tokens = self.model_instance.get_num_tokens( + [PromptMessage(content=self.llm_message.completion)] + ) + self.conversation_message_task.save_message(llm_message=self.llm_message) + raise ConversationTaskInterruptException() + except ConversationTaskInterruptException as e: + raise e except Exception as e: logging.error("Moderation Output error: %s", e) return False diff --git a/api/core/completion.py b/api/core/completion.py index 2e85af8dff8a03..920924424deee9 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -10,7 +10,8 @@ from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler -from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ + ConversationTaskInterruptException from core.external_data_tool.factory import ExternalDataToolFactory from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ @@ -150,7 +151,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer memory=memory, fake_response=fake_response ) - except ConversationTaskStoppedException: + except (ConversationTaskInterruptException, ConversationTaskStoppedException): return except ChunkedEncodingError as e: # Interrupt by LLM (like OpenAI), handle it. diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 0358d5cd3dec0e..9dd211d36087df 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -465,3 +465,7 @@ def stop(cls, user: Union[Account | EndUser], task_id: str): class ConversationTaskStoppedException(Exception): pass + + +class ConversationTaskInterruptException(Exception): + pass diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 29575d2dab3e22..a26ba8613f5eaf 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -10,7 +10,8 @@ from sqlalchemy import and_ from core.completion import Completion -from core.conversation_message_task import PubHandler, ConversationTaskStoppedException +from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \ + ConversationTaskInterruptException from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ LLMRateLimitError, \ LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError @@ -199,7 +200,7 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m is_override=is_model_config_override, retriever_from=retriever_from ) - except ConversationTaskStoppedException: + except (ConversationTaskInterruptException, ConversationTaskStoppedException): pass except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, From 3b7e8890d4806526a23d4886ffc7b13d0f28d2e7 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 20:40:57 +0800 Subject: [PATCH 42/57] refactor. --- api/controllers/console/app/completion.py | 15 ---- api/controllers/console/explore/completion.py | 14 +--- api/controllers/service_api/app/completion.py | 14 +--- api/controllers/web/completion.py | 12 --- .../callback_handler/llm_callback_handler.py | 4 +- api/core/completion.py | 25 +++++-- api/core/moderation/api/api.py | 75 +++++++++++++++---- api/core/moderation/base.py | 17 ++++- api/core/moderation/factory.py | 8 +- api/core/moderation/keywords/keywords.py | 25 ++++--- .../openai_moderation/openai_moderation.py | 22 +++--- api/models/api_based_extension.py | 2 + api/services/moderation_service.py | 8 +- 13 files changed, 133 insertions(+), 108 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index ef41d8e5beae8b..c3e4837269265e 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -208,22 +208,7 @@ def post(self, app_id, task_id): return {'result': 'success'}, 200 -class ModerationApi(Resource): - - @setup_required - @login_required - @account_initialization_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, location='json') - parser.add_argument('text', type=str, required=True, location='json') - args = parser.parse_args() - - service = ModerationService() - return service.moderation_for_outputs(_get_app(str(args['app_id']), None), args['text']) - api.add_resource(CompletionMessageApi, '/apps//completion-messages') api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') api.add_resource(ChatMessageApi, '/apps//chat-messages') api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') -api.add_resource(ModerationApi, '/apps/moderation') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 991f991de96bd8..b0a9a959f8163f 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -175,19 +175,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') -class ModerationApi(InstalledAppResource): - - def post(self, installed_app): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, location='json') - parser.add_argument('text', type=str, required=True, location='json') - args = parser.parse_args() - - service = ModerationService() - return service.moderation_for_outputs(installed_app.app, args['text']) - api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') -api.add_resource(ModerationApi, '/installed-apps//moderation', endpoint='installed_app_moderation') \ No newline at end of file +api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') \ No newline at end of file diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 9be8a10e7b9c1f..2ec089d8c543f9 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -179,19 +179,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') -class ModerationAPI(AppApiResource): - - def post(self, app_model, end_user): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, location='json') - parser.add_argument('text', type=str, required=True, location='json') - args = parser.parse_args() - - service = ModerationService() - return service.moderation_for_outputs(app_model, args['text']) - api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') -api.add_resource(ModerationAPI, '/moderation') \ No newline at end of file +api.add_resource(ChatStopApi, '/chat-messages//stop') \ No newline at end of file diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0410e697337664..315430e4db4110 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -173,19 +173,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') -class ModerationAPI(WebApiResource): - - def post(self, app_model, end_user): - parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, location='json') - parser.add_argument('text', type=str, required=True, location='json') - args = parser.parse_args() - - service = ModerationService() - return service.moderation_for_outputs(app_model, args['text']) - api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') api.add_resource(ChatStopApi, '/chat-messages//stop') -api.add_resource(ModerationAPI, '/moderation') diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index aae10c707c2c2e..3cc025890e02e2 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -10,7 +10,7 @@ ConversationTaskInterruptException from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage from core.model_providers.models.llm.base import BaseLLM -from core.moderation.base import ModerationOutputsResult, ModerationOutputsAction +from core.moderation.base import ModerationOutputsResult, ModerationAction from core.moderation.factory import ModerationFactory @@ -161,7 +161,7 @@ def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: if not result.flagged: return False - if result.action == ModerationOutputsAction.DIRECT_OUTPUT: + if result.action == ModerationAction.DIRECT_OUTPUT: self.is_interrupt = True self.llm_message.completion = result.preset_response else: diff --git a/api/core/completion.py b/api/core/completion.py index 920924424deee9..05ea1955ceac4c 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -23,7 +23,7 @@ from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from models.model import App, AppModelConfig, Account, Conversation, EndUser -from core.moderation.base import ModerationException +from core.moderation.base import ModerationException, ModerationAction from core.moderation.factory import ModerationFactory @@ -87,7 +87,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer try: # process sensitive_word_avoidance - cls.moderation_for_inputs(app.tenant_id, app_model_config, inputs, query) + inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query) except ModerationException as e: cls.run_final_llm( model_instance=final_model_instance, @@ -160,14 +160,27 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer return @classmethod - def moderation_for_inputs(cls, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): + def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str) -> Tuple[dict, str]: if not app_model_config.sensitive_word_avoidance_dict['enabled']: - return + return inputs, query type = app_model_config.sensitive_word_avoidance_dict['type'] - moderation = ModerationFactory(type, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) - moderation.moderation_for_inputs(inputs, query) + moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) + moderation_result = moderation.moderation_for_inputs(inputs, query) + + if not moderation_result.flagged: + return inputs, query + + if moderation_result.action == ModerationAction.DIRECT_OUTPUT: + raise ModerationException(moderation_result.preset_response) + elif moderation_result.action == ModerationAction.OVERRIDED: + inputs = moderation_result.inputs + query = moderation_result.query + + return inputs, query + + @classmethod def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict], diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 6df772c2f66486..42a156ff11b556 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,10 +1,24 @@ -from typing import Optional +from typing import Union +from pydantic import BaseModel -from core.moderation.base import Moderation +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction, ModerationException +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint +from core.helper.encrypter import decrypt_token from extensions.ext_database import db from models.api_based_extension import APIBasedExtension +class ModerationInputParams(BaseModel): + app_id: str = "" + inputs: dict = {} + query: str = "" + + +class ModerationOutputParams(BaseModel): + app_id: str = "" + text: str + + class ApiModeration(Moderation): name: str = "api" @@ -23,20 +37,55 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() - - if not api_based_extension: + extension = cls._get_api_based_extension(tenant_id, api_based_extension_id) + if not extension: raise ValueError("api_based_extension_id is invalid") - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - pass + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + flagged = False + preset_response = "" - def moderation_for_outputs(self, text: str): - pass + if self.config['inputs_config']['enabled']: + params = ModerationInputParams( + app_id=self.app_id, + inputs=inputs, + query=query + ) + + result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict()) + return ModerationInputsResult(**result) + + return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: + flagged = False + preset_response = "" + if self.config['outputs_config']['enabled']: + params = ModerationOutputParams( + app_id=self.app_id, + text=text + ) + + result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict()) + return ModerationOutputsResult(**result) + + return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + + + def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: + extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) + + result = requestor.request(extension_point, params) + return result + + @staticmethod + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + extension = db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() + return extension \ No newline at end of file diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 41c367a03a821c..bffc05d7151334 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -6,14 +6,22 @@ from core.extension.extensible import Extensible, ExtensionModule -class ModerationOutputsAction(Enum): +class ModerationAction(Enum): DIRECT_OUTPUT = 'direct_output' OVERRIDED = 'overrided' +class ModerationInputsResult(BaseModel): + flagged: bool = False + action: ModerationAction + preset_response: str = "" + inputs: dict = {} + query: str = "" + + class ModerationOutputsResult(BaseModel): flagged: bool = False - action: ModerationOutputsAction + action: ModerationAction preset_response: str = "" text: str = "" @@ -24,8 +32,9 @@ class Moderation(Extensible, ABC): """ module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: super().__init__(tenant_id, config) + self.app_id = app_id @classmethod @abstractmethod @@ -40,7 +49,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: raise NotImplementedError @abstractmethod - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """ Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 18a72e1e7c899a..e0ec5c3548c002 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,16 +1,16 @@ from typing import Optional from core.extension.extensible import ExtensionModule -from core.moderation.base import Moderation, ModerationOutputsResult +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: __extension_instance: Moderation - def __init__(self, name: str, tenant_id: str, config: dict) -> None: + def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - self.__extension_instance = extension_class(tenant_id, config) + self.__extension_instance = extension_class(app_id, tenant_id, config) @classmethod def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: @@ -26,7 +26,7 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) extension_class.validate_config(tenant_id, config) - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """ Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index d6d21606583671..d30f044b80bb30 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,6 +1,6 @@ from typing import Optional -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOutputsAction +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction class KeywordsModeration(Moderation): @@ -20,19 +20,20 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - if not self.config['inputs_config']['enabled']: - return - - if query: - inputs['query__'] = query + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + flagged = False + preset_response = "" - keywords_list = self.config['keywords'].split('\n') - preset_response = self.config['inputs_config']['preset_response'] + if self.config['inputs_config']['enabled']: + preset_response = self.config['inputs_config']['preset_response'] - if self._is_violated(inputs, keywords_list): - raise ModerationException(preset_response) + if query: + inputs['query__'] = query + keywords_list = self.config['keywords'].split('\n') + flagged = self._is_violated(inputs, keywords_list) + return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" @@ -42,7 +43,7 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = self._is_violated({'text': text}, keywords_list) preset_response = self.config['outputs_config']['preset_response'] - return ModerationOutputsResult(flagged=flagged, action=ModerationOutputsAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 491779427f46f5..15eb9f9eac4f52 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -3,7 +3,7 @@ from typing import Optional from core.helper.encrypter import decrypt_token -from core.moderation.base import Moderation, ModerationException, ModerationOutputsResult, ModerationOutputsAction +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction from extensions.ext_database import db from models.provider import Provider @@ -22,16 +22,18 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: """ cls._validate_inputs_and_outputs_config(config, True) - def moderation_for_inputs(self, inputs: dict, query: Optional[str] = None): - if not self.config['inputs_config']['enabled']: - return + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + flagged = False + preset_response = "" - preset_response = self.config['inputs_config']['preset_response'] - if query: - inputs['query__'] = query + if self.config['inputs_config']['enabled']: + preset_response = self.config['inputs_config']['preset_response'] - if self._is_violated(inputs): - raise ModerationException(preset_response) + if query: + inputs['query__'] = query + flagged = self._is_violated(inputs) + + return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False @@ -41,7 +43,7 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = self._is_violated({ 'text': text }) preset_response = self.config['outputs_config']['preset_response'] - return ModerationOutputsResult(flagged=flagged, action=ModerationOutputsAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict): diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index dc19f6d751d971..28ff868c146ffb 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -8,6 +8,8 @@ class APIBasedExtensionPoint(enum.Enum): APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' PING = 'ping' + APP_MODERATION_INPUT = 'app.moderation.input' + APP_MODERATION_OUTPUT = 'app.moderation.output' class APIBasedExtension(db.Model): diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index 58b979b591ebbf..9886fc8ae99b10 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,12 +1,12 @@ import json from models.model import AppModelConfig, App -from core.moderation.factory import ModerationFactory +from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db class ModerationService: - def moderation_for_outputs(self, app_model: App, text: str) -> dict: + def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() @@ -17,6 +17,6 @@ def moderation_for_outputs(self, app_model: App, text: str) -> dict: name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['config'] - moderation = ModerationFactory(name, app_model.tenant_id, config) - data = moderation.moderation_for_outputs(text).json() + moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) + data = moderation.moderation_for_outputs(text).json() return json.loads(data) \ No newline at end of file From ffbf9b5228b631f034cb49d52299c35c53f7d28c Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 20:44:40 +0800 Subject: [PATCH 43/57] update. --- api/controllers/console/app/completion.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/web/completion.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index c3e4837269265e..1da7bd8f2ce974 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -23,7 +23,6 @@ from flask_restful import Resource, reqparse from services.completion_service import CompletionService -from services.moderation_service import ModerationService # define completion message api for user @@ -208,6 +207,7 @@ def post(self, app_id, task_id): return {'result': 'success'}, 200 + api.add_resource(CompletionMessageApi, '/apps//completion-messages') api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') api.add_resource(ChatMessageApi, '/apps//chat-messages') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index b0a9a959f8163f..f7642f2ed7e650 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -19,7 +19,6 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService -from services.moderation_service import ModerationService # define completion api for user class CompletionApi(InstalledAppResource): @@ -175,6 +174,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') + api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2ec089d8c543f9..3f2126bcbed149 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -18,7 +18,6 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService -from services.moderation_service import ModerationService class CompletionApi(AppApiResource): @@ -179,6 +178,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') + api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 315430e4db4110..79c0c542d17852 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -18,7 +18,6 @@ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from libs.helper import uuid_value from services.completion_service import CompletionService -from services.moderation_service import ModerationService # define completion api for user @@ -173,6 +172,7 @@ def generate() -> Generator: return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') + api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') From f48fd419b44515c87c3b4f58d36f8951bebfa2cd Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 3 Nov 2023 20:51:02 +0800 Subject: [PATCH 44/57] update. --- api/controllers/console/explore/completion.py | 3 ++- api/controllers/service_api/app/completion.py | 2 +- api/services/moderation_service.py | 5 +---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index f7642f2ed7e650..bdf1f3b907dfdd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -20,6 +20,7 @@ from libs.helper import uuid_value from services.completion_service import CompletionService + # define completion api for user class CompletionApi(InstalledAppResource): @@ -178,4 +179,4 @@ def generate() -> Generator: api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') \ No newline at end of file +api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3f2126bcbed149..5ab8a7d116ab4a 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -182,4 +182,4 @@ def generate() -> Generator: api.add_resource(CompletionApi, '/completion-messages') api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') \ No newline at end of file +api.add_resource(ChatStopApi, '/chat-messages//stop') diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index 9886fc8ae99b10..f933c9abc0a0fa 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,5 +1,3 @@ -import json - from models.model import AppModelConfig, App from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db @@ -18,5 +16,4 @@ def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> Mode config = app_model_config.sensitive_word_avoidance_dict['config'] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) - data = moderation.moderation_for_outputs(text).json() - return json.loads(data) \ No newline at end of file + return moderation.moderation_for_outputs(text) \ No newline at end of file From 80441d5d5be51dd7eff4ecab4e9d3a16a97f567e Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 21:18:02 +0800 Subject: [PATCH 45/57] feat: multi thread output moderation --- .../callback_handler/llm_callback_handler.py | 123 +++++++++++++++--- 1 file changed, 108 insertions(+), 15 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 3cc025890e02e2..06d125e9f42d7a 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -1,6 +1,9 @@ import logging +import threading +import time from typing import Any, Dict, List, Union +from flask import Flask, current_app from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult, BaseMessage from pydantic import BaseModel @@ -32,9 +35,12 @@ def __init__(self, model_instance: BaseLLM, app_model_config = self.conversation_message_task.app_model_config sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict + self.is_moderation_working = False + self.direct_output_response = None self.moderation_rule = None - self.moderation_buffer = '' self.moderation_chunk = '' + self.moderation_buffer = '' + self.moderation_thread = None if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): self.moderation_rule = ModerationRule( type=sensitive_word_avoidance_dict.get("type"), @@ -81,13 +87,16 @@ def on_llm_start( def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: if not self.conversation_message_task.streaming: - moderation_result = self.moderation_completion(response.generations[0][0].text, True) + moderation_result = self.moderation_completion_async(response.generations[0][0].text, True) if not moderation_result: self.llm_message.completion = response.generations[0][0].text self.conversation_message_task.append_message_text(self.llm_message.completion) else: - self.moderation_completion(self.llm_message.completion, True) + if len(self.llm_message.completion) < 300: + self.moderation_completion_async(self.llm_message.completion, True) + elif self.moderation_chunk: + self.moderation_completion_async(self.moderation_chunk, True) if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: @@ -102,10 +111,17 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)]) + while self.is_moderation_working: + time.sleep(0.1) + + if self.direct_output_response: + raise ConversationTaskInterruptException() + self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - self.moderation_completion(token) + if self.direct_output_response: + raise ConversationTaskInterruptException() try: self.conversation_message_task.append_message_text(token) @@ -113,12 +129,15 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None: self.on_llm_error(error=ex) raise ex + self.moderation_completion_async(token) + self.llm_message.completion += token def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: """Do nothing.""" + self.is_moderation_working = False if isinstance(error, ConversationTaskStoppedException): if self.conversation_message_task.streaming: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( @@ -126,7 +145,11 @@ def on_llm_error( ) self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) if isinstance(error, ConversationTaskInterruptException): - pass + self.llm_message.completion = self.direct_output_response + self.llm_message.completion_tokens = self.model_instance.get_num_tokens( + [PromptMessage(content=self.llm_message.completion)] + ) + self.conversation_message_task.save_message(llm_message=self.llm_message) else: logging.debug("on_llm_error: %s", error) @@ -141,13 +164,14 @@ def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: return False if not no_chunk: - if len(self.moderation_chunk) < 50: - self.moderation_chunk += token + self.moderation_chunk += token + self.moderation_buffer += token + if len(self.moderation_chunk) < 300: return False else: + self.moderation_buffer += token self.moderation_chunk = token - moderation_chunk = self.moderation_chunk self.moderation_chunk = '' try: @@ -157,22 +181,22 @@ def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: config=self.moderation_rule.config ) - result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_chunk) + logging.info('Moderation params: %s', self.moderation_buffer) + result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(self.moderation_buffer) if not result.flagged: return False if result.action == ModerationAction.DIRECT_OUTPUT: - self.is_interrupt = True self.llm_message.completion = result.preset_response else: - self.llm_message.completion = self.moderation_buffer + result.text + self.moderation_chunk + self.llm_message.completion = result.text + self.moderation_chunk if self.conversation_message_task.streaming: # trigger replace event - logging.debug("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) + logging.info("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) self.conversation_message_task.on_message_replace(self.llm_message.completion) - if result.action == ModerationOutputsAction.DIRECT_OUTPUT: + if result.action == ModerationAction.DIRECT_OUTPUT: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)] ) @@ -183,7 +207,76 @@ def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: except Exception as e: logging.error("Moderation Output error: %s", e) return False - finally: - self.moderation_buffer += moderation_chunk return True + + def moderation_completion_async(self, token: str, no_chunk: bool = False) -> bool: + """ + Moderation for outputs. + + :param token: LLM output content + :return: bool + """ + if not self.moderation_rule: + return False + + if not no_chunk: + self.moderation_chunk += token + self.moderation_buffer += token + if len(self.moderation_chunk) < 300: + return False + else: + self.moderation_buffer += token + self.moderation_chunk = token + + self.moderation_chunk = '' + + if not self.moderation_thread: + self.moderation_thread = threading.Thread(target=self.moderation_worker, kwargs={ + 'flask_app': current_app._get_current_object() + }) + + self.moderation_thread.start() + + def moderation_worker(self, flask_app: Flask): + with flask_app.app_context(): + self.is_moderation_working = True + current_length = 0 + while self.is_moderation_working: + moderation_buffer = self.moderation_buffer + buffer_length = len(moderation_buffer) + if buffer_length - current_length < 300: + if buffer_length - current_length == 0: + self.is_moderation_working = False + break + + time.sleep(0.1) + continue + + current_length = buffer_length + + try: + moderation_factory = ModerationFactory( + name=self.moderation_rule.type, + tenant_id=self.conversation_message_task.tenant_id, + config=self.moderation_rule.config + ) + + logging.info('Moderation params: %s', moderation_buffer) + result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) + if not result.flagged: + continue + + if result.action == ModerationAction.DIRECT_OUTPUT: + self.is_moderation_working = False + self.llm_message.completion = result.preset_response + self.direct_output_response = result.preset_response + else: + self.llm_message.completion = result.text + self.moderation_buffer[len(moderation_buffer):] + + if self.conversation_message_task.streaming: + # trigger replace event + logging.info("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) + self.conversation_message_task.on_message_replace(self.llm_message.completion) + except Exception as e: + logging.error("Moderation Output error: %s", e) From 5c01a3eb35afcb8796e8b54ea97c454659916335 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 3 Nov 2023 21:29:36 +0800 Subject: [PATCH 46/57] fix: bug --- api/core/callback_handler/llm_callback_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 06d125e9f42d7a..10b9154e4b11b4 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -258,6 +258,7 @@ def moderation_worker(self, flask_app: Flask): try: moderation_factory = ModerationFactory( name=self.moderation_rule.type, + app_id=self.conversation_message_task.app.id, tenant_id=self.conversation_message_task.tenant_id, config=self.moderation_rule.config ) From 1be2578d546609eb64c18aa9d541a9dae5103b2a Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Sat, 4 Nov 2023 17:13:33 +0800 Subject: [PATCH 47/57] fix bug. --- api/controllers/console/extension.py | 6 ++---- api/services/api_based_extension_service.py | 5 ++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index c2feff6d716da9..d01bc9a80aa298 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -55,7 +55,7 @@ def post(self): api_key=args['api_key'] ) - return APIBasedExtensionService.save(extension_data, True) + return APIBasedExtensionService.save(extension_data) class APIBasedExtensionDetailAPI(Resource): @@ -88,12 +88,10 @@ def post(self, id): extension_data_from_db.name = args['name'] extension_data_from_db.api_endpoint = args['api_endpoint'] - need_encrypt = False if args['api_key'] != '[__HIDDEN__]': - need_encrypt = True extension_data_from_db.api_key = args['api_key'] - return APIBasedExtensionService.save(extension_data_from_db, need_encrypt) + return APIBasedExtensionService.save(extension_data_from_db) @setup_required @login_required diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 53e08de47a676f..867ec5b5dedc75 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -18,11 +18,10 @@ def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: return extension_list @classmethod - def save(cls, extension_data: APIBasedExtension, need_encrypt: bool) -> APIBasedExtension: + def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension: cls._validation(extension_data) - if need_encrypt: - extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) + extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) db.session.add(extension_data) db.session.commit() From 6f56b01f849454ac0636978b5cb6bbd462c8101f Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Sat, 4 Nov 2023 18:07:56 +0800 Subject: [PATCH 48/57] add validation. --- api/core/completion.py | 2 +- api/core/moderation/api/api.py | 2 +- api/core/moderation/base.py | 17 +++++++++++++---- api/core/moderation/keywords/keywords.py | 3 +++ 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/api/core/completion.py b/api/core/completion.py index 05ea1955ceac4c..cbb09208cfbbaf 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -160,7 +160,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer return @classmethod - def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str) -> Tuple[dict, str]: + def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): if not app_model_config.sensitive_word_avoidance_dict['enabled']: return inputs, query diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 42a156ff11b556..61944d2684cb3e 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -39,7 +39,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: extension = cls._get_api_based_extension(tenant_id, api_based_extension_id) if not extension: - raise ValueError("api_based_extension_id is invalid") + raise ValueError("API-based Extension not found. Please check it again.") def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index bffc05d7151334..7c07bc29b1dcb5 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -94,11 +94,20 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r if not is_preset_response_required: return - if inputs_config_enabled and not inputs_config.get("preset_response"): - raise ValueError("inputs_config.preset_response is required") + if inputs_config_enabled: + if not inputs_config.get("preset_response"): + raise ValueError("inputs_config.preset_response is required") + + if len(inputs_config.get("preset_response")) > 100: + raise ValueError("inputs_config.preset_response must be less than 100 characters") + - if outputs_config_enabled and not outputs_config.get("preset_response"): - raise ValueError("outputs_config.preset_response is required") + if outputs_config_enabled: + if not outputs_config.get("preset_response"): + raise ValueError("outputs_config.preset_response is required") + + if len(outputs_config.get("preset_response")) > 100: + raise ValueError("outputs_config.preset_response must be less than 100 characters") class ModerationException(Exception): pass \ No newline at end of file diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index d30f044b80bb30..2743ecfbb70158 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -19,6 +19,9 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") + + if len(config.get("keywords")) > 1000: + raise ValueError("keywords length must be less than 1000") def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False From 1b735a17599871980e265d08941ca083df8eb2a1 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Sun, 5 Nov 2023 01:21:59 +0800 Subject: [PATCH 49/57] format. --- api/controllers/console/extension.py | 17 +++++++------ api/controllers/web/completion.py | 2 +- api/core/completion.py | 8 +++--- api/core/extension/extension.py | 1 - api/core/moderation/api/api.py | 25 ++++++++----------- api/core/moderation/base.py | 14 +++++------ api/core/moderation/factory.py | 4 +-- api/core/moderation/keywords/keywords.py | 10 +++----- .../openai_moderation/openai_moderation.py | 11 ++++---- api/fields/api_based_extension_fields.py | 3 ++- api/models/api_based_extension.py | 2 +- api/services/code_based_extension_service.py | 2 +- api/services/completion_service.py | 10 ++++---- api/services/moderation_service.py | 5 ++-- 14 files changed, 53 insertions(+), 61 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index d01bc9a80aa298..50b33e39ad4c9c 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -25,7 +25,7 @@ def get(self): 'module': args['module'], 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) } - + class APIBasedExtensionAPI(Resource): @@ -36,7 +36,7 @@ class APIBasedExtensionAPI(Resource): def get(self): tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) - + @setup_required @login_required @account_initialization_required @@ -56,9 +56,10 @@ def post(self): ) return APIBasedExtensionService.save(extension_data) - + + class APIBasedExtensionDetailAPI(Resource): - + @setup_required @login_required @account_initialization_required @@ -66,9 +67,9 @@ class APIBasedExtensionDetailAPI(Resource): def get(self, id): api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id - + return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) - + @setup_required @login_required @account_initialization_required @@ -103,11 +104,11 @@ def delete(self, id): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) APIBasedExtensionService.delete(extension_data_from_db) - + return {'result': 'success'} api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') \ No newline at end of file +api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 79c0c542d17852..579c1761c548d9 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -139,7 +139,7 @@ def post(self, app_model, end_user, task_id): return {'result': 'success'}, 200 -def compact_response(response: Union[dict | Generator]) -> Response: +def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') else: diff --git a/api/core/completion.py b/api/core/completion.py index cbb09208cfbbaf..7eaacd486f9e3c 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -84,7 +84,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer try: chain_callback = MainChainGatherCallbackHandler(conversation_message_task) - + try: # process sensitive_word_avoidance inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query) @@ -158,7 +158,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer logging.warning(f'ChunkedEncodingError: {e}') conversation_message_task.end() return - + @classmethod def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): if not app_model_config.sensitive_word_avoidance_dict['enabled']: @@ -179,8 +179,6 @@ def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: Ap query = moderation_result.query return inputs, query - - @classmethod def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict], @@ -256,7 +254,7 @@ def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str: if app.mode != 'completion': return query - + return inputs.get(app_model_config.dataset_query_variable, "") @classmethod diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 845484cb1a1da6..6517e41ccd1120 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -45,4 +45,3 @@ def validate_form_schema(self, module: ExtensionModule, extension_name: str, con form_schema = module_extension.form_schema # TODO validate form_schema - diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 61944d2684cb3e..9ef584cd1ad5b4 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,7 +1,6 @@ -from typing import Union from pydantic import BaseModel -from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction, ModerationException +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint from core.helper.encrypter import decrypt_token from extensions.ext_database import db @@ -51,12 +50,11 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu inputs=inputs, query=query ) - + result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict()) return ModerationInputsResult(**result) - + return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) - def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False @@ -70,16 +68,15 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict()) return ModerationOutputsResult(**result) - - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: - extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) - requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) - - result = requestor.request(extension_point, params) - return result + extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) + + result = requestor.request(extension_point, params) + return result @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: @@ -87,5 +84,5 @@ def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> API APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ).first() - - return extension \ No newline at end of file + + return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 7c07bc29b1dcb5..ce4e574038d05a 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -79,7 +79,7 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): raise ValueError("inputs_config must be a dict") - + # outputs_config outputs_config = config.get("outputs_config") if not isinstance(outputs_config, dict): @@ -93,21 +93,21 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r # preset_response if not is_preset_response_required: return - + if inputs_config_enabled: if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - + if len(inputs_config.get("preset_response")) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") - - + if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - + if len(outputs_config.get("preset_response")) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") + class ModerationException(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index e0ec5c3548c002..96bf2ab54b41eb 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.extension.extensible import ExtensionModule from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult from extensions.ext_code_based_extension import code_based_extension @@ -47,4 +45,4 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: :param text: LLM output content :return: """ - return self.__extension_instance.moderation_for_outputs(text) \ No newline at end of file + return self.__extension_instance.moderation_for_outputs(text) diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 2743ecfbb70158..168b9d43f806f7 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction @@ -19,7 +17,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - + if len(config.get("keywords")) > 1000: raise ValueError("keywords length must be less than 1000") @@ -36,7 +34,7 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = self._is_violated(inputs, keywords_list) return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) - + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" @@ -52,9 +50,9 @@ def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): if self._check_keywords_in_value(keywords_list, value): return True - + return False - + def _check_keywords_in_value(self, keywords_list, value): for keyword in keywords_list: if keyword.lower() in value.lower(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 15eb9f9eac4f52..954800e36b39ee 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,5 @@ import openai import json -from typing import Optional from core.helper.encrypter import decrypt_token from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction @@ -32,7 +31,7 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu if query: inputs['query__'] = query flagged = self._is_violated(inputs) - + return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: @@ -40,7 +39,7 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: preset_response = "" if self.config['outputs_config']['enabled']: - flagged = self._is_violated({ 'text': text }) + flagged = self._is_violated({'text': text}) preset_response = self.config['outputs_config']['preset_response'] return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) @@ -53,9 +52,9 @@ def _is_violated(self, inputs: dict): for result in moderation_result.results: if result['flagged']: return True - + return False - + def _get_openai_api_key(self) -> str: provider = db.session.query(Provider) \ .filter_by(tenant_id=self.tenant_id) \ @@ -67,4 +66,4 @@ def _get_openai_api_key(self) -> str: encrypted_config = json.loads(provider.encrypted_config) - return decrypt_token(self.tenant_id, encrypted_config['openai_api_key']) \ No newline at end of file + return decrypt_token(self.tenant_id, encrypted_config['openai_api_key']) diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 6dc0a0b823bc3d..f6a60ff02e83bb 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -7,10 +7,11 @@ class HiddenAPIKey(fields.Raw): def output(self, key, obj): return obj.api_key[:3] + '***' + obj.api_key[-3:] + api_based_extension_fields = { 'id': fields.String, 'name': fields.String, 'api_endpoint': fields.String, 'api_key': HiddenAPIKey, 'created_at': TimestampField -} \ No newline at end of file +} diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 28ff868c146ffb..e34cfb8f7b3371 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -24,4 +24,4 @@ class APIBasedExtension(db.Model): name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) \ No newline at end of file + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 38d8d112b513cb..7b0d50a835dbf8 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,7 +2,7 @@ class CodeBasedExtensionService: - + @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) diff --git a/api/services/completion_service.py b/api/services/completion_service.py index a26ba8613f5eaf..bba46047880e9b 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -29,9 +29,9 @@ class CompletionService: @classmethod - def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, + def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, from_source: str, streaming: bool = True, - is_model_config_override: bool = False) -> Union[dict | Generator]: + is_model_config_override: bool = False) -> Union[dict, Generator]: # is streaming mode inputs = args['inputs'] query = args['query'] @@ -244,9 +244,9 @@ def close_pubsub(): return countdown_thread @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], + def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, streaming: bool = True, - retriever_from: str = 'dev') -> Union[dict | Generator]: + retriever_from: str = 'dev') -> Union[dict, Generator]: if not user: raise ValueError('user cannot be None') @@ -342,7 +342,7 @@ def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig) return filtered_inputs @classmethod - def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]: + def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]: generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') if not streaming: try: diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index f933c9abc0a0fa..d0be7bca80aeb1 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -2,6 +2,7 @@ from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db + class ModerationService: def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: @@ -11,9 +12,9 @@ def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> Mode if not app_model_config: raise ValueError("app model config not found") - + name = app_model_config.sensitive_word_avoidance_dict['type'] config = app_model_config.sensitive_word_avoidance_dict['config'] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) - return moderation.moderation_for_outputs(text) \ No newline at end of file + return moderation.moderation_for_outputs(text) From e07fac5d78900cf3c91cff549a01cfbfaa458c0b Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Sun, 5 Nov 2023 16:34:46 +0800 Subject: [PATCH 50/57] format. --- api/services/api_based_extension_service.py | 29 +++++++++++---------- api/services/app_model_config_service.py | 26 +++++++++--------- api/services/completion_service.py | 2 +- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 867ec5b5dedc75..d4e7d5be3d03f2 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -3,15 +3,16 @@ from core.helper.encrypter import encrypt_token, decrypt_token from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor + class APIBasedExtensionService: @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ + extension_list = db.session.query(APIBasedExtension) \ .filter_by(tenant_id=tenant_id) \ .order_by(APIBasedExtension.created_at.desc()) \ .all() - + for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -38,27 +39,27 @@ def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedE .filter_by(tenant_id=tenant_id) \ .filter_by(id=api_based_extension_id) \ .first() - + if not extension: raise ValueError("API based extension is not found") - + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) - + return extension - + @classmethod def _validation(cls, extension_data: APIBasedExtension) -> None: # name if not extension_data.name: raise ValueError("name must not be empty") - + if not extension_data.id: # case one: check new data, name must be unique is_name_existed = db.session.query(APIBasedExtension) \ .filter_by(tenant_id=extension_data.tenant_id) \ .filter_by(name=extension_data.name) \ .first() - + if is_name_existed: raise ValueError("name must be unique, it is already existed") else: @@ -68,24 +69,24 @@ def _validation(cls, extension_data: APIBasedExtension) -> None: .filter_by(name=extension_data.name) \ .filter(APIBasedExtension.id != extension_data.id) \ .first() - + if is_name_existed: raise ValueError("name must be unique, it is already existed") # api_endpoint if not extension_data.api_endpoint: raise ValueError("api_endpoint must not be empty") - + # api_key if not extension_data.api_key: raise ValueError("api_key must not be empty") - + if len(extension_data.api_key) < 5: raise ValueError("api_key must be at least 5 characters") - + # check endpoint cls._ping_connection(extension_data) - + @staticmethod def _ping_connection(extension_data: APIBasedExtension) -> None: try: @@ -94,4 +95,4 @@ def _ping_connection(extension_data: APIBasedExtension) -> None: if resp.get('result') != 'pong': raise ValueError(resp) except Exception as e: - raise ValueError("connection error: {}".format(e)) \ No newline at end of file + raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 64809c8fc604d9..b9b4e84735f29a 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -59,7 +59,7 @@ def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") - + if len(cp["stop"]) > 4: raise ValueError("stop sequences must be less than 4") @@ -179,7 +179,7 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, model_ids = [m['id'] for m in model_list] if config["model"]["name"] not in model_ids: raise ValueError("model.name must be in the specified model list") - + # model.mode if 'mode' not in config['model'] or not config['model']["mode"]: config['model']["mode"] = "" @@ -307,7 +307,7 @@ def validate_configuration(cls, tenant_id: str, account: Account, config: dict, if not cls.is_dataset_exists(account, tool_item["id"]): raise ValueError("Dataset ID does not exist, please check your permission.") - + # dataset_query_variable cls.is_dataset_query_variable_valid(config, mode) @@ -363,10 +363,10 @@ def is_moderation_valid(cls, tenant_id: str, config: dict): if not config["sensitive_word_avoidance"]["enabled"]: return - + if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: raise ValueError("sensitive_word_avoidance.type is required") - + type = config["sensitive_word_avoidance"]["type"] config = config["sensitive_word_avoidance"]["config"] @@ -408,16 +408,15 @@ def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: # Only check when mode is completion if mode != 'completion': return - + agent_mode = config.get("agent_mode", {}) tools = agent_mode.get("tools", []) dataset_exists = "dataset" in str(tools) - + dataset_query_variable = config.get("dataset_query_variable") if dataset_exists and not dataset_query_variable: raise ValueError("Dataset query variable is required when dataset is exist") - @classmethod def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: @@ -427,7 +426,7 @@ def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: if config['prompt_type'] not in ['simple', 'advanced']: raise ValueError("prompt_type must be in ['simple', 'advanced']") - + # chat_prompt_config if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: config["chat_prompt_config"] = {} @@ -441,7 +440,7 @@ def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - + # dataset_configs if 'dataset_configs' not in config or not config["dataset_configs"]: config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} @@ -452,10 +451,10 @@ def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: if config['prompt_type'] == 'advanced': if not config['chat_prompt_config'] and not config['completion_prompt_config']: raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") - + if config['model']["mode"] not in ['chat', 'completion']: raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - + if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value: user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] @@ -466,9 +465,8 @@ def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: if not assistant_prefix: config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' - if config['model']["mode"] == ModelMode.CHAT.value: prompt_list = config['chat_prompt_config']['prompt'] if len(prompt_list) > 10: - raise ValueError("prompt messages must be less than 10") \ No newline at end of file + raise ValueError("prompt messages must be less than 10") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index bba46047880e9b..81e35c85939770 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -235,7 +235,7 @@ def close_pubsub(): PubHandler.stop(user, generate_task_id) try: pubsub.close() - except: + except Exception: pass countdown_thread = threading.Thread(target=close_pubsub) From c88371abc9105c607f4c0cad09fec170f6c0e6a8 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 5 Nov 2023 17:34:48 +0800 Subject: [PATCH 51/57] fix: bug --- .../callback_handler/llm_callback_handler.py | 85 +++---------------- 1 file changed, 13 insertions(+), 72 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 10b9154e4b11b4..3cc254ce44d8b6 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -35,7 +35,6 @@ def __init__(self, model_instance: BaseLLM, app_model_config = self.conversation_message_task.app_model_config sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict - self.is_moderation_working = False self.direct_output_response = None self.moderation_rule = None self.moderation_chunk = '' @@ -111,8 +110,8 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)]) - while self.is_moderation_working: - time.sleep(0.1) + if self.moderation_thread: + self.moderation_thread.join() if self.direct_output_response: raise ConversationTaskInterruptException() @@ -120,24 +119,24 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - if self.direct_output_response: - raise ConversationTaskInterruptException() - try: + if self.direct_output_response: + raise ConversationTaskInterruptException() + self.conversation_message_task.append_message_text(token) + self.moderation_completion_async(token) + self.llm_message.completion += token except ConversationTaskStoppedException as ex: self.on_llm_error(error=ex) raise ex - - self.moderation_completion_async(token) - - self.llm_message.completion += token + except ConversationTaskInterruptException as ex: + self.on_llm_error(error=ex) + raise ex def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: """Do nothing.""" - self.is_moderation_working = False if isinstance(error, ConversationTaskStoppedException): if self.conversation_message_task.streaming: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( @@ -153,63 +152,6 @@ def on_llm_error( else: logging.debug("on_llm_error: %s", error) - def moderation_completion(self, token: str, no_chunk: bool = False) -> bool: - """ - Moderation for outputs. - - :param token: LLM output content - :return: bool - """ - if not self.moderation_rule: - return False - - if not no_chunk: - self.moderation_chunk += token - self.moderation_buffer += token - if len(self.moderation_chunk) < 300: - return False - else: - self.moderation_buffer += token - self.moderation_chunk = token - - self.moderation_chunk = '' - - try: - moderation_factory = ModerationFactory( - name=self.moderation_rule.type, - tenant_id=self.conversation_message_task.tenant_id, - config=self.moderation_rule.config - ) - - logging.info('Moderation params: %s', self.moderation_buffer) - result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(self.moderation_buffer) - if not result.flagged: - return False - - if result.action == ModerationAction.DIRECT_OUTPUT: - self.llm_message.completion = result.preset_response - else: - self.llm_message.completion = result.text + self.moderation_chunk - - if self.conversation_message_task.streaming: - # trigger replace event - logging.info("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) - self.conversation_message_task.on_message_replace(self.llm_message.completion) - - if result.action == ModerationAction.DIRECT_OUTPUT: - self.llm_message.completion_tokens = self.model_instance.get_num_tokens( - [PromptMessage(content=self.llm_message.completion)] - ) - self.conversation_message_task.save_message(llm_message=self.llm_message) - raise ConversationTaskInterruptException() - except ConversationTaskInterruptException as e: - raise e - except Exception as e: - logging.error("Moderation Output error: %s", e) - return False - - return True - def moderation_completion_async(self, token: str, no_chunk: bool = False) -> bool: """ Moderation for outputs. @@ -240,14 +182,12 @@ def moderation_completion_async(self, token: str, no_chunk: bool = False) -> boo def moderation_worker(self, flask_app: Flask): with flask_app.app_context(): - self.is_moderation_working = True current_length = 0 - while self.is_moderation_working: + while True: moderation_buffer = self.moderation_buffer buffer_length = len(moderation_buffer) if buffer_length - current_length < 300: if buffer_length - current_length == 0: - self.is_moderation_working = False break time.sleep(0.1) @@ -269,7 +209,6 @@ def moderation_worker(self, flask_app: Flask): continue if result.action == ModerationAction.DIRECT_OUTPUT: - self.is_moderation_working = False self.llm_message.completion = result.preset_response self.direct_output_response = result.preset_response else: @@ -279,5 +218,7 @@ def moderation_worker(self, flask_app: Flask): # trigger replace event logging.info("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) self.conversation_message_task.on_message_replace(self.llm_message.completion) + if result.action == ModerationAction.DIRECT_OUTPUT: + break except Exception as e: logging.error("Moderation Output error: %s", e) From 5918902e2f54857a0fc606ba70710c41664a9fcb Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 5 Nov 2023 18:48:47 +0800 Subject: [PATCH 52/57] fix: bug --- .../callback_handler/llm_callback_handler.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 3cc254ce44d8b6..a8baeb28ff0bc1 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -39,6 +39,7 @@ def __init__(self, model_instance: BaseLLM, self.moderation_rule = None self.moderation_chunk = '' self.moderation_buffer = '' + self.moderation_final_chunk = False self.moderation_thread = None if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): self.moderation_rule = ModerationRule( @@ -92,10 +93,7 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.conversation_message_task.append_message_text(self.llm_message.completion) else: - if len(self.llm_message.completion) < 300: - self.moderation_completion_async(self.llm_message.completion, True) - elif self.moderation_chunk: - self.moderation_completion_async(self.moderation_chunk, True) + self.moderation_completion_async(self.llm_message.completion, True) if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: @@ -114,24 +112,25 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.moderation_thread.join() if self.direct_output_response: - raise ConversationTaskInterruptException() + ex = ConversationTaskInterruptException() + self.on_llm_error(error=ex) + raise ex self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - try: - if self.direct_output_response: - raise ConversationTaskInterruptException() + if self.direct_output_response: + ex = ConversationTaskInterruptException() + self.on_llm_error(error=ex) + raise ex + try: self.conversation_message_task.append_message_text(token) self.moderation_completion_async(token) self.llm_message.completion += token except ConversationTaskStoppedException as ex: self.on_llm_error(error=ex) raise ex - except ConversationTaskInterruptException as ex: - self.on_llm_error(error=ex) - raise ex def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -157,6 +156,7 @@ def moderation_completion_async(self, token: str, no_chunk: bool = False) -> boo Moderation for outputs. :param token: LLM output content + :param no_chunk: whether to chunk the token :return: bool """ if not self.moderation_rule: @@ -168,8 +168,8 @@ def moderation_completion_async(self, token: str, no_chunk: bool = False) -> boo if len(self.moderation_chunk) < 300: return False else: - self.moderation_buffer += token - self.moderation_chunk = token + self.moderation_buffer = token + self.moderation_final_chunk = True self.moderation_chunk = '' @@ -186,7 +186,7 @@ def moderation_worker(self, flask_app: Flask): while True: moderation_buffer = self.moderation_buffer buffer_length = len(moderation_buffer) - if buffer_length - current_length < 300: + if not self.moderation_final_chunk and buffer_length - current_length < 300: if buffer_length - current_length == 0: break From 7fb83e8d467598bf381ad7b6fb8c83d57a671eff Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 6 Nov 2023 08:35:49 +0800 Subject: [PATCH 53/57] feat: refactor output moderation --- .../callback_handler/llm_callback_handler.py | 222 +++++++++++------- 1 file changed, 137 insertions(+), 85 deletions(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index a8baeb28ff0bc1..78e1f34de68672 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional from flask import Flask, current_app from langchain.callbacks.base import BaseCallbackHandler @@ -32,19 +32,22 @@ def __init__(self, model_instance: BaseLLM, self.start_at = None self.conversation_message_task = conversation_message_task + self.output_moderation_handler = None + self.init_output_moderation() + + def init_output_moderation(self): app_model_config = self.conversation_message_task.app_model_config sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict - self.direct_output_response = None - self.moderation_rule = None - self.moderation_chunk = '' - self.moderation_buffer = '' - self.moderation_final_chunk = False - self.moderation_thread = None if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): - self.moderation_rule = ModerationRule( - type=sensitive_word_avoidance_dict.get("type"), - config=sensitive_word_avoidance_dict.get("config") + self.output_moderation_handler = OutputModerationHandler( + tenant_id=self.conversation_message_task.tenant_id, + app_id=self.conversation_message_task.app.id, + rule=ModerationRule( + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config") + ), + on_message_replace_func=self.conversation_message_task.on_message_replace ) @property @@ -86,14 +89,18 @@ def on_llm_start( self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - if not self.conversation_message_task.streaming: - moderation_result = self.moderation_completion_async(response.generations[0][0].text, True) - if not moderation_result: - self.llm_message.completion = response.generations[0][0].text + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() - self.conversation_message_task.append_message_text(self.llm_message.completion) + self.llm_message.completion = self.output_moderation_handler.moderation_completion( + completion=response.generations[0][0].text, + public_event=True if self.conversation_message_task.streaming else False + ) else: - self.moderation_completion_async(self.llm_message.completion, True) + self.llm_message.completion = response.generations[0][0].text + + if not self.conversation_message_task.streaming: + self.conversation_message_task.append_message_text(self.llm_message.completion) if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: @@ -108,26 +115,21 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)]) - if self.moderation_thread: - self.moderation_thread.join() - - if self.direct_output_response: - ex = ConversationTaskInterruptException() - self.on_llm_error(error=ex) - raise ex - self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - if self.direct_output_response: + if self.output_moderation_handler and self.output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output ex = ConversationTaskInterruptException() self.on_llm_error(error=ex) raise ex try: self.conversation_message_task.append_message_text(token) - self.moderation_completion_async(token) self.llm_message.completion += token + + if self.output_moderation_handler: + self.output_moderation_handler.append_new_token(token) except ConversationTaskStoppedException as ex: self.on_llm_error(error=ex) raise ex @@ -136,6 +138,9 @@ def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: """Do nothing.""" + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() + if isinstance(error, ConversationTaskStoppedException): if self.conversation_message_task.streaming: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( @@ -143,7 +148,7 @@ def on_llm_error( ) self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) if isinstance(error, ConversationTaskInterruptException): - self.llm_message.completion = self.direct_output_response + self.llm_message.completion = self.output_moderation_handler.get_final_output() self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)] ) @@ -151,74 +156,121 @@ def on_llm_error( else: logging.debug("on_llm_error: %s", error) - def moderation_completion_async(self, token: str, no_chunk: bool = False) -> bool: - """ - Moderation for outputs. - - :param token: LLM output content - :param no_chunk: whether to chunk the token - :return: bool - """ - if not self.moderation_rule: - return False - - if not no_chunk: - self.moderation_chunk += token - self.moderation_buffer += token - if len(self.moderation_chunk) < 300: - return False + +class OutputModerationHandler(BaseModel): + BUFFER_SIZE: int = 300 + + tenant_id: str + app_id: str + + rule: ModerationRule + on_message_replace_func: Any + + thread: Optional[threading.Thread] = None + thread_running: bool = True + buffer: str = '' + is_final_chunk: bool = False + final_output: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + + def should_direct_output(self): + return self.final_output is not None + + def get_final_output(self): + return self.final_output + + def append_new_token(self, token: str): + self.buffer += token + + if not self.thread: + self.thread = self.start_thread() + + def moderation_completion(self, completion: str, public_event: bool = False) -> str: + self.buffer = completion + self.is_final_chunk = True + + result = self.moderation( + tenant_id=self.tenant_id, + app_id=self.app_id, + moderation_buffer=completion + ) + + if not result or not result.flagged: + return completion + + if result.action == ModerationAction.DIRECT_OUTPUT: + final_output = result.preset_response else: - self.moderation_buffer = token - self.moderation_final_chunk = True + final_output = result.text - self.moderation_chunk = '' + if public_event: + self.on_message_replace_func(final_output) - if not self.moderation_thread: - self.moderation_thread = threading.Thread(target=self.moderation_worker, kwargs={ - 'flask_app': current_app._get_current_object() - }) + return final_output - self.moderation_thread.start() + def start_thread(self) -> threading.Thread: + thread = threading.Thread(target=self.worker, kwargs={ + 'flask_app': current_app._get_current_object() + }) - def moderation_worker(self, flask_app: Flask): + thread.start() + + return thread + + def stop_thread(self): + if self.thread and self.thread.is_alive(): + self.thread_running = False + + def worker(self, flask_app: Flask): with flask_app.app_context(): current_length = 0 - while True: - moderation_buffer = self.moderation_buffer + while self.thread_running: + moderation_buffer = self.buffer buffer_length = len(moderation_buffer) - if not self.moderation_final_chunk and buffer_length - current_length < 300: - if buffer_length - current_length == 0: - break + if not self.is_final_chunk: + chunk_length = buffer_length - current_length + if 0 <= chunk_length < self.BUFFER_SIZE: + time.sleep(1) + continue - time.sleep(0.1) + current_length = buffer_length + + result = self.moderation( + tenant_id=self.tenant_id, + app_id=self.app_id, + moderation_buffer=moderation_buffer + ) + + if not result or not result.flagged: continue - current_length = buffer_length + if result.action == ModerationAction.DIRECT_OUTPUT: + final_output = result.preset_response + else: + final_output = result.text + self.buffer[len(moderation_buffer):] - try: - moderation_factory = ModerationFactory( - name=self.moderation_rule.type, - app_id=self.conversation_message_task.app.id, - tenant_id=self.conversation_message_task.tenant_id, - config=self.moderation_rule.config - ) - - logging.info('Moderation params: %s', moderation_buffer) - result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) - if not result.flagged: - continue + # trigger replace event + if self.thread_running: + self.on_message_replace_func(final_output) + + if result.action == ModerationAction.DIRECT_OUTPUT: + self.final_output = final_output + break + + def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: + try: + moderation_factory = ModerationFactory( + name=self.rule.type, + app_id=app_id, + tenant_id=tenant_id, + config=self.rule.config + ) + + result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) + return result + except Exception as e: + logging.error("Moderation Output error: %s", e) - if result.action == ModerationAction.DIRECT_OUTPUT: - self.llm_message.completion = result.preset_response - self.direct_output_response = result.preset_response - else: - self.llm_message.completion = result.text + self.moderation_buffer[len(moderation_buffer):] - - if self.conversation_message_task.streaming: - # trigger replace event - logging.info("Moderation %s replace event: %s", result.action.value, self.llm_message.completion) - self.conversation_message_task.on_message_replace(self.llm_message.completion) - if result.action == ModerationAction.DIRECT_OUTPUT: - break - except Exception as e: - logging.error("Moderation Output error: %s", e) + return None From 176a0128279e5a2ac32c8718d60cb89aa5f2f9b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 6 Nov 2023 08:48:56 +0800 Subject: [PATCH 54/57] feat: add buffer size setting to env --- api/config.py | 4 ++++ api/core/callback_handler/llm_callback_handler.py | 10 ++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/api/config.py b/api/config.py index d9be647c0ed519..ad099ab15830da 100644 --- a/api/config.py +++ b/api/config.py @@ -57,6 +57,7 @@ 'CLEAN_DAY_SETTING': 30, 'UPLOAD_FILE_SIZE_LIMIT': 15, 'UPLOAD_FILE_BATCH_LIMIT': 5, + 'OUTPUT_MODERATION_BUFFER_SIZE': 300 } @@ -228,6 +229,9 @@ def __init__(self): self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) + # moderation settings + self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE')) + class CloudEditionConfig(Config): diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 78e1f34de68672..136b8c893ecb46 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -158,7 +158,7 @@ def on_llm_error( class OutputModerationHandler(BaseModel): - BUFFER_SIZE: int = 300 + DEFAULT_BUFFER_SIZE: int = 300 tenant_id: str app_id: str @@ -211,8 +211,10 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> return final_output def start_thread(self) -> threading.Thread: + buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE)) thread = threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object() + 'flask_app': current_app._get_current_object(), + 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE }) thread.start() @@ -223,7 +225,7 @@ def stop_thread(self): if self.thread and self.thread.is_alive(): self.thread_running = False - def worker(self, flask_app: Flask): + def worker(self, flask_app: Flask, buffer_size: int): with flask_app.app_context(): current_length = 0 while self.thread_running: @@ -231,7 +233,7 @@ def worker(self, flask_app: Flask): buffer_length = len(moderation_buffer) if not self.is_final_chunk: chunk_length = buffer_length - current_length - if 0 <= chunk_length < self.BUFFER_SIZE: + if 0 <= chunk_length < buffer_size: time.sleep(1) continue From c5206d649b3532d751321d0c680eed712196b79c Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 6 Nov 2023 12:55:05 +0800 Subject: [PATCH 55/57] feat: optimize final output --- api/core/callback_handler/llm_callback_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 136b8c893ecb46..109dda68cdc07b 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -250,6 +250,7 @@ def worker(self, flask_app: Flask, buffer_size: int): if result.action == ModerationAction.DIRECT_OUTPUT: final_output = result.preset_response + self.final_output = final_output else: final_output = result.text + self.buffer[len(moderation_buffer):] @@ -258,7 +259,6 @@ def worker(self, flask_app: Flask, buffer_size: int): self.on_message_replace_func(final_output) if result.action == ModerationAction.DIRECT_OUTPUT: - self.final_output = final_output break def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: From 5a15f2751e30eb9d3a8e01ef02b0c4c7e56ddedd Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 6 Nov 2023 18:54:57 +0800 Subject: [PATCH 56/57] bug fixed. --- .../openai_moderation/openai_moderation.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 954800e36b39ee..7dea2c6bf75513 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -3,8 +3,7 @@ from core.helper.encrypter import decrypt_token from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction -from extensions.ext_database import db -from models.provider import Provider +from core.model_providers.model_factory import ModelFactory class OpenAIModeration(Moderation): @@ -56,14 +55,10 @@ def _is_violated(self, inputs: dict): return False def _get_openai_api_key(self) -> str: - provider = db.session.query(Provider) \ - .filter_by(tenant_id=self.tenant_id) \ - .filter_by(provider_name="openai") \ - .first() - - if not provider: + model_class_obj = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation") + if not model_class_obj: raise ValueError("openai provider is not configured") - encrypted_config = json.loads(provider.encrypted_config) + encrypted_config = json.loads(model_class_obj.model_provider.provider.encrypted_config) return decrypt_token(self.tenant_id, encrypted_config['openai_api_key']) From 040672cf6b453e5103055a8feaa17efd1a092c5f Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 6 Nov 2023 19:21:29 +0800 Subject: [PATCH 57/57] bug fix. --- .../openai_moderation/openai_moderation.py | 26 +++---------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 7dea2c6bf75513..c5817b19011f59 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,7 +1,3 @@ -import openai -import json - -from core.helper.encrypter import decrypt_token from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction from core.model_providers.model_factory import ModelFactory @@ -44,21 +40,7 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) def _is_violated(self, inputs: dict): - - openai_api_key = self._get_openai_api_key() - moderation_result = openai.Moderation.create(input=list(inputs.values()), api_key=openai_api_key) - - for result in moderation_result.results: - if result['flagged']: - return True - - return False - - def _get_openai_api_key(self) -> str: - model_class_obj = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation") - if not model_class_obj: - raise ValueError("openai provider is not configured") - - encrypted_config = json.loads(model_class_obj.model_provider.provider.encrypted_config) - - return decrypt_token(self.tenant_id, encrypted_config['openai_api_key']) + text = '\n'.join(inputs.values()) + openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation") + is_not_invalid = openai_moderation.run(text) + return not is_not_invalid