From db43ed6f4172ad357e2c3cd19bb263cb569a9d4f Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Mon, 6 Nov 2023 19:36:16 +0800 Subject: [PATCH] feat: add api-based extension & external data tool & moderation backend (#1403) Co-authored-by: takatost --- {.vscode => api/.vscode}/launch.json | 2 +- api/app.py | 3 +- api/config.py | 4 + api/controllers/console/__init__.py | 2 +- api/controllers/console/explore/parameter.py | 4 +- api/controllers/console/extension.py | 114 +++++++++++ api/controllers/service_api/app/app.py | 4 +- api/controllers/service_api/app/completion.py | 1 - api/controllers/web/app.py | 4 +- api/controllers/web/completion.py | 2 +- api/core/__init__.py | 1 + .../callback_handler/llm_callback_handler.py | 189 +++++++++++++++++- .../chain/sensitive_word_avoidance_chain.py | 92 --------- api/core/completion.py | 157 ++++++++++++--- api/core/conversation_message_task.py | 26 +++ api/core/extension/__init__.py | 0 .../api_based_extension_requestor.py | 62 ++++++ api/core/extension/extensible.py | 111 ++++++++++ api/core/extension/extension.py | 47 +++++ api/core/external_data_tool/__init__.py | 0 api/core/external_data_tool/api/__builtin__ | 1 + api/core/external_data_tool/api/__init__.py | 0 api/core/external_data_tool/api/api.py | 92 +++++++++ api/core/external_data_tool/base.py | 45 +++++ api/core/external_data_tool/factory.py | 40 ++++ api/core/moderation/__init__.py | 0 api/core/moderation/api/__builtin__ | 1 + api/core/moderation/api/__init__.py | 0 api/core/moderation/api/api.py | 88 ++++++++ api/core/moderation/base.py | 113 +++++++++++ api/core/moderation/factory.py | 48 +++++ api/core/moderation/keywords/__builtin__ | 1 + api/core/moderation/keywords/__init__.py | 0 api/core/moderation/keywords/keywords.py | 60 ++++++ .../moderation/openai_moderation/__builtin__ | 1 + .../moderation/openai_moderation/__init__.py | 0 .../openai_moderation/openai_moderation.py | 46 +++++ api/core/orchestrator_rule_parser.py | 47 ----- 
api/extensions/ext_code_based_extension.py | 8 + api/fields/api_based_extension_fields.py | 17 ++ api/fields/app_fields.py | 1 + .../968fff4c0ab9_add_api_based_extension.py | 45 +++++ ...e_add_external_data_tools_in_app_model_.py | 32 +++ api/models/api_based_extension.py | 27 +++ api/models/model.py | 44 ++-- api/services/api_based_extension_service.py | 98 +++++++++ api/services/app_model_config_service.py | 143 ++++++++----- api/services/code_based_extension_service.py | 13 ++ api/services/completion_service.py | 37 +++- api/services/moderation_service.py | 20 ++ 50 files changed, 1622 insertions(+), 271 deletions(-) rename {.vscode => api/.vscode}/launch.json (94%) create mode 100644 api/controllers/console/extension.py delete mode 100644 api/core/chain/sensitive_word_avoidance_chain.py create mode 100644 api/core/extension/__init__.py create mode 100644 api/core/extension/api_based_extension_requestor.py create mode 100644 api/core/extension/extensible.py create mode 100644 api/core/extension/extension.py create mode 100644 api/core/external_data_tool/__init__.py create mode 100644 api/core/external_data_tool/api/__builtin__ create mode 100644 api/core/external_data_tool/api/__init__.py create mode 100644 api/core/external_data_tool/api/api.py create mode 100644 api/core/external_data_tool/base.py create mode 100644 api/core/external_data_tool/factory.py create mode 100644 api/core/moderation/__init__.py create mode 100644 api/core/moderation/api/__builtin__ create mode 100644 api/core/moderation/api/__init__.py create mode 100644 api/core/moderation/api/api.py create mode 100644 api/core/moderation/base.py create mode 100644 api/core/moderation/factory.py create mode 100644 api/core/moderation/keywords/__builtin__ create mode 100644 api/core/moderation/keywords/__init__.py create mode 100644 api/core/moderation/keywords/keywords.py create mode 100644 api/core/moderation/openai_moderation/__builtin__ create mode 100644 
api/core/moderation/openai_moderation/__init__.py create mode 100644 api/core/moderation/openai_moderation/openai_moderation.py create mode 100644 api/extensions/ext_code_based_extension.py create mode 100644 api/fields/api_based_extension_fields.py create mode 100644 api/migrations/versions/968fff4c0ab9_add_api_based_extension.py create mode 100644 api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py create mode 100644 api/models/api_based_extension.py create mode 100644 api/services/api_based_extension_service.py create mode 100644 api/services/code_based_extension_service.py create mode 100644 api/services/moderation_service.py diff --git a/.vscode/launch.json b/api/.vscode/launch.json similarity index 94% rename from .vscode/launch.json rename to api/.vscode/launch.json index 515deb4c0cf85b..e3c1f797c61601 100644 --- a/.vscode/launch.json +++ b/api/.vscode/launch.json @@ -10,7 +10,7 @@ "request": "launch", "module": "flask", "env": { - "FLASK_APP": "api/app.py", + "FLASK_APP": "app.py", "FLASK_DEBUG": "1", "GEVENT_SUPPORT": "True" }, diff --git a/api/app.py b/api/app.py index ef32ad44e1f70c..bc0d25224e8b04 100644 --- a/api/app.py +++ b/api/app.py @@ -19,7 +19,7 @@ from core.model_providers.providers import hosted from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ - ext_database, ext_storage, ext_mail, ext_stripe + ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension from extensions.ext_database import db from extensions.ext_login import login_manager @@ -79,6 +79,7 @@ def create_app(test_config=None) -> Flask: def initialize_extensions(app): # Since the application instance is now created, pass it to each Flask # extension instance to bind it to the Flask application instance (app) + ext_code_based_extension.init() ext_database.init_app(app) ext_migrate.init(app, db) ext_redis.init_app(app) diff --git a/api/config.py b/api/config.py index d9be647c0ed519..ad099ab15830da 100644 --- 
a/api/config.py +++ b/api/config.py @@ -57,6 +57,7 @@ 'CLEAN_DAY_SETTING': 30, 'UPLOAD_FILE_SIZE_LIMIT': 15, 'UPLOAD_FILE_BATCH_LIMIT': 5, + 'OUTPUT_MODERATION_BUFFER_SIZE': 300 } @@ -228,6 +229,9 @@ def __init__(self): self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) + # moderation settings + self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE')) + class CloudEditionConfig(Config): diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 2476d918870c6f..ac881dc126c0d0 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -6,7 +6,7 @@ api = ExternalApi(bp) # Import other controllers -from . import setup, version, apikey, admin +from . import extension, setup, version, apikey, admin # Import app controllers from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fb4ce33209ec71..9514834a2be937 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -27,6 +27,7 @@ class AppParameterApi(InstalledAppResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -42,7 +43,8 @@ def get(self, installed_app: InstalledApp): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } diff --git 
a/api/controllers/console/extension.py b/api/controllers/console/extension.py new file mode 100644 index 00000000000000..50b33e39ad4c9c --- /dev/null +++ b/api/controllers/console/extension.py @@ -0,0 +1,114 @@ +from flask_restful import Resource, reqparse, marshal_with +from flask_login import current_user + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from models.api_based_extension import APIBasedExtension +from fields.api_based_extension_fields import api_based_extension_fields +from services.code_based_extension_service import CodeBasedExtensionService +from services.api_based_extension_service import APIBasedExtensionService + + +class CodeBasedExtensionAPI(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('module', type=str, required=True, location='args') + args = parser.parse_args() + + return { + 'module': args['module'], + 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) + } + + +class APIBasedExtensionAPI(Resource): + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def get(self): + tenant_id = current_user.current_tenant_id + return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('api_endpoint', type=str, required=True, location='json') + parser.add_argument('api_key', type=str, required=True, location='json') + args = parser.parse_args() + + extension_data = APIBasedExtension( + tenant_id=current_user.current_tenant_id, + 
name=args['name'], + api_endpoint=args['api_endpoint'], + api_key=args['api_key'] + ) + + return APIBasedExtensionService.save(extension_data) + + +class APIBasedExtensionDetailAPI(Resource): + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def get(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_based_extension_fields) + def post(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('api_endpoint', type=str, required=True, location='json') + parser.add_argument('api_key', type=str, required=True, location='json') + args = parser.parse_args() + + extension_data_from_db.name = args['name'] + extension_data_from_db.api_endpoint = args['api_endpoint'] + + if args['api_key'] != '[__HIDDEN__]': + extension_data_from_db.api_key = args['api_key'] + + return APIBasedExtensionService.save(extension_data_from_db) + + @setup_required + @login_required + @account_initialization_required + def delete(self, id): + api_based_extension_id = str(id) + tenant_id = current_user.current_tenant_id + + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + + APIBasedExtensionService.delete(extension_data_from_db) + + return {'result': 'success'} + + +api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') + +api.add_resource(APIBasedExtensionAPI, '/api-based-extension') +api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') diff --git 
a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 86b8642571108b..cca4c1addba504 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -28,6 +28,7 @@ class AppParameterApi(AppApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -42,7 +43,8 @@ def get(self, app_model: App, end_user): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index a339322ea89a8e..5ab8a7d116ab4a 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -183,4 +183,3 @@ def generate() -> Generator: api.add_resource(CompletionStopApi, '/completion-messages//stop') api.add_resource(ChatApi, '/chat-messages') api.add_resource(ChatStopApi, '/chat-messages//stop') - diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cffda04eea66d5..bb99e26ad165ac 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -27,6 +27,7 @@ class AppParameterApi(WebApiResource): 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, + 'sensitive_word_avoidance': fields.Raw } @marshal_with(parameters_fields) @@ -41,7 +42,8 @@ def get(self, app_model: App, end_user): 'speech_to_text': app_model_config.speech_to_text_dict, 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, - 
'user_input_form': app_model_config.user_input_form_list + 'user_input_form': app_model_config.user_input_form_list, + 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict } diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 79c0c542d17852..579c1761c548d9 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -139,7 +139,7 @@ def post(self, app_model, end_user, task_id): return {'result': 'success'}, 200 -def compact_response(response: Union[dict | Generator]) -> Response: +def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') else: diff --git a/api/core/__init__.py b/api/core/__init__.py index e69de29bb2d1d6..8c986fc8bd8afa 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -0,0 +1 @@ +import core.moderation.base \ No newline at end of file diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index b8eb99b2e5bfc3..109dda68cdc07b 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -1,13 +1,25 @@ import logging -from typing import Any, Dict, List, Union +import threading +import time +from typing import Any, Dict, List, Union, Optional +from flask import Flask, current_app from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import LLMResult, BaseMessage +from pydantic import BaseModel from core.callback_handler.entity.llm_message import LLMMessage -from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ + ConversationTaskInterruptException from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage 
from core.model_providers.models.llm.base import BaseLLM +from core.moderation.base import ModerationOutputsResult, ModerationAction +from core.moderation.factory import ModerationFactory + + +class ModerationRule(BaseModel): + type: str + config: Dict[str, Any] class LLMCallbackHandler(BaseCallbackHandler): @@ -20,6 +32,24 @@ def __init__(self, model_instance: BaseLLM, self.start_at = None self.conversation_message_task = conversation_message_task + self.output_moderation_handler = None + self.init_output_moderation() + + def init_output_moderation(self): + app_model_config = self.conversation_message_task.app_model_config + sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict + + if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): + self.output_moderation_handler = OutputModerationHandler( + tenant_id=self.conversation_message_task.tenant_id, + app_id=self.conversation_message_task.app.id, + rule=ModerationRule( + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config") + ), + on_message_replace_func=self.conversation_message_task.on_message_replace + ) + @property def always_verbose(self) -> bool: """Whether to call verbose callbacks even if verbose is False.""" @@ -59,10 +89,19 @@ def on_llm_start( self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - if not self.conversation_message_task.streaming: - self.conversation_message_task.append_message_text(response.generations[0][0].text) + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() + + self.llm_message.completion = self.output_moderation_handler.moderation_completion( + completion=response.generations[0][0].text, + public_event=True if self.conversation_message_task.streaming else False + ) + else: self.llm_message.completion = response.generations[0][0].text + 
if not self.conversation_message_task.streaming: + self.conversation_message_task.append_message_text(self.llm_message.completion) + if response.llm_output and 'token_usage' in response.llm_output: if 'prompt_tokens' in response.llm_output['token_usage']: self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] @@ -79,23 +118,161 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.conversation_message_task.save_message(self.llm_message) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + if self.output_moderation_handler and self.output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + ex = ConversationTaskInterruptException() + self.on_llm_error(error=ex) + raise ex + try: self.conversation_message_task.append_message_text(token) + self.llm_message.completion += token + + if self.output_moderation_handler: + self.output_moderation_handler.append_new_token(token) except ConversationTaskStoppedException as ex: self.on_llm_error(error=ex) raise ex - self.llm_message.completion += token - def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: """Do nothing.""" + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() + if isinstance(error, ConversationTaskStoppedException): if self.conversation_message_task.streaming: self.llm_message.completion_tokens = self.model_instance.get_num_tokens( [PromptMessage(content=self.llm_message.completion)] ) self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) + if isinstance(error, ConversationTaskInterruptException): + self.llm_message.completion = self.output_moderation_handler.get_final_output() + self.llm_message.completion_tokens = self.model_instance.get_num_tokens( + [PromptMessage(content=self.llm_message.completion)] + ) + 
self.conversation_message_task.save_message(llm_message=self.llm_message) else: logging.debug("on_llm_error: %s", error) + + +class OutputModerationHandler(BaseModel): + DEFAULT_BUFFER_SIZE: int = 300 + + tenant_id: str + app_id: str + + rule: ModerationRule + on_message_replace_func: Any + + thread: Optional[threading.Thread] = None + thread_running: bool = True + buffer: str = '' + is_final_chunk: bool = False + final_output: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + + def should_direct_output(self): + return self.final_output is not None + + def get_final_output(self): + return self.final_output + + def append_new_token(self, token: str): + self.buffer += token + + if not self.thread: + self.thread = self.start_thread() + + def moderation_completion(self, completion: str, public_event: bool = False) -> str: + self.buffer = completion + self.is_final_chunk = True + + result = self.moderation( + tenant_id=self.tenant_id, + app_id=self.app_id, + moderation_buffer=completion + ) + + if not result or not result.flagged: + return completion + + if result.action == ModerationAction.DIRECT_OUTPUT: + final_output = result.preset_response + else: + final_output = result.text + + if public_event: + self.on_message_replace_func(final_output) + + return final_output + + def start_thread(self) -> threading.Thread: + buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE)) + thread = threading.Thread(target=self.worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE + }) + + thread.start() + + return thread + + def stop_thread(self): + if self.thread and self.thread.is_alive(): + self.thread_running = False + + def worker(self, flask_app: Flask, buffer_size: int): + with flask_app.app_context(): + current_length = 0 + while self.thread_running: + moderation_buffer = self.buffer + buffer_length = len(moderation_buffer) + 
if not self.is_final_chunk: + chunk_length = buffer_length - current_length + if 0 <= chunk_length < buffer_size: + time.sleep(1) + continue + + current_length = buffer_length + + result = self.moderation( + tenant_id=self.tenant_id, + app_id=self.app_id, + moderation_buffer=moderation_buffer + ) + + if not result or not result.flagged: + continue + + if result.action == ModerationAction.DIRECT_OUTPUT: + final_output = result.preset_response + self.final_output = final_output + else: + final_output = result.text + self.buffer[len(moderation_buffer):] + + # trigger replace event + if self.thread_running: + self.on_message_replace_func(final_output) + + if result.action == ModerationAction.DIRECT_OUTPUT: + break + + def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: + try: + moderation_factory = ModerationFactory( + name=self.rule.type, + app_id=app_id, + tenant_id=tenant_id, + config=self.rule.config + ) + + result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) + return result + except Exception as e: + logging.error("Moderation Output error: %s", e) + + return None diff --git a/api/core/chain/sensitive_word_avoidance_chain.py b/api/core/chain/sensitive_word_avoidance_chain.py deleted file mode 100644 index 62d58542751b2f..00000000000000 --- a/api/core/chain/sensitive_word_avoidance_chain.py +++ /dev/null @@ -1,92 +0,0 @@ -import enum -import logging -from typing import List, Dict, Optional, Any - -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain -from pydantic import BaseModel - -from core.model_providers.error import LLMBadRequestError -from core.model_providers.model_factory import ModelFactory -from core.model_providers.models.llm.base import BaseLLM -from core.model_providers.models.moderation import openai_moderation - - -class SensitiveWordAvoidanceRule(BaseModel): - class Type(enum.Enum): - 
MODERATION = "moderation" - KEYWORDS = "keywords" - - type: Type - canned_response: str = 'Your content violates our usage policy. Please revise and try again.' - extra_params: dict = {} - - -class SensitiveWordAvoidanceChain(Chain): - input_key: str = "input" #: :meta private: - output_key: str = "output" #: :meta private: - - model_instance: BaseLLM - sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule - - @property - def _chain_type(self) -> str: - return "sensitive_word_avoidance_chain" - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return output key. - - :meta private: - """ - return [self.output_key] - - def _check_sensitive_word(self, text: str) -> bool: - for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []): - if word in text: - return False - return True - - def _check_moderation(self, text: str) -> bool: - moderation_model_instance = ModelFactory.get_moderation_model( - tenant_id=self.model_instance.model_provider.provider.tenant_id, - model_provider_name='openai', - model_name=openai_moderation.DEFAULT_MODEL - ) - - try: - return moderation_model_instance.run(text=text) - except Exception as ex: - logging.exception(ex) - raise LLMBadRequestError('Rate limit exceeded, please try again later.') - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - text = inputs[self.input_key] - - if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS: - result = self._check_sensitive_word(text) - else: - result = self._check_moderation(text) - - if not result: - raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response) - - return {self.output_key: text} - - -class SensitiveWordAvoidanceError(Exception): - def __init__(self, message): - super().__init__(message) - 
self.message = message diff --git a/api/core/completion.py b/api/core/completion.py index 57e18199271ccb..7eaacd486f9e3c 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,13 +1,18 @@ +import concurrent +import json import logging -from typing import Optional, List, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, List, Union, Tuple +from flask import current_app, Flask from requests.exceptions import ChunkedEncodingError from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler -from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError -from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ + ConversationTaskInterruptException +from core.external_data_tool.factory import ExternalDataToolFactory from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ ReadOnlyConversationTokenDBBufferSharedMemory @@ -18,6 +23,8 @@ from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from models.model import App, AppModelConfig, Account, Conversation, EndUser +from core.moderation.base import ModerationException, ModerationAction +from core.moderation.factory import ModerationFactory class Completion: @@ -76,26 +83,35 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer ) try: - # parse sensitive_word_avoidance_chain chain_callback = MainChainGatherCallbackHandler(conversation_message_task) - sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain( - 
final_model_instance, [chain_callback]) - if sensitive_word_avoidance_chain: - try: - query = sensitive_word_avoidance_chain.run(query) - except SensitiveWordAvoidanceError as ex: - cls.run_final_llm( - model_instance=final_model_instance, - mode=app.mode, - app_model_config=app_model_config, - query=query, - inputs=inputs, - agent_execute_result=None, - conversation_message_task=conversation_message_task, - memory=memory, - fake_response=ex.message - ) - return + + try: + # process sensitive_word_avoidance + inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query) + except ModerationException as e: + cls.run_final_llm( + model_instance=final_model_instance, + mode=app.mode, + app_model_config=app_model_config, + query=query, + inputs=inputs, + agent_execute_result=None, + conversation_message_task=conversation_message_task, + memory=memory, + fake_response=str(e) + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_model_config.external_data_tools_list + if external_data_tools: + inputs = cls.fill_in_inputs_from_external_data_tools( + tenant_id=app.tenant_id, + app_id=app.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) # get agent executor agent_executor = orchestrator_rule_parser.to_agent_executor( @@ -135,19 +151,110 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer memory=memory, fake_response=fake_response ) - except ConversationTaskStoppedException: + except (ConversationTaskInterruptException, ConversationTaskStoppedException): return except ChunkedEncodingError as e: # Interrupt by LLM (like OpenAI), handle it. 
logging.warning(f'ChunkedEncodingError: {e}') conversation_message_task.end() return - + + @classmethod + def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): + if not app_model_config.sensitive_word_avoidance_dict['enabled']: + return inputs, query + + type = app_model_config.sensitive_word_avoidance_dict['type'] + + moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) + moderation_result = moderation.moderation_for_inputs(inputs, query) + + if not moderation_result.flagged: + return inputs, query + + if moderation_result.action == ModerationAction.DIRECT_OUTPUT: + raise ModerationException(moderation_result.preset_response) + elif moderation_result.action == ModerationAction.OVERRIDED: + inputs = moderation_result.inputs + query = moderation_result.query + + return inputs, query + + @classmethod + def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict], + inputs: dict, query: str) -> dict: + """ + Fill in variable inputs from external data tools if exists. 
+ + :param tenant_id: workspace id + :param app_id: app id + :param external_data_tools: external data tools configs + :param inputs: the inputs + :param query: the query + :return: the filled inputs + """ + # Group tools by type and config + grouped_tools = {} + for tool in external_data_tools: + if not tool.get("enabled"): + continue + + tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True)) + grouped_tools.setdefault(tool_key, []).append(tool) + + results = {} + with ThreadPoolExecutor() as executor: + futures = {} + for tools in grouped_tools.values(): + # Only query the first tool in each group + first_tool = tools[0] + future = executor.submit( + cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool, + inputs, query + ) + for tool in tools: + futures[future] = tool + + for future in concurrent.futures.as_completed(futures): + tool_key, result = future.result() + if tool_key in grouped_tools: + for tool in grouped_tools[tool_key]: + results[tool['variable']] = result + + inputs.update(results) + return inputs + + @classmethod + def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict, + inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]: + with flask_app.app_context(): + tool_variable = external_data_tool.get("variable") + tool_type = external_data_tool.get("type") + tool_config = external_data_tool.get("config") + + external_data_tool_factory = ExternalDataToolFactory( + name=tool_type, + tenant_id=tenant_id, + app_id=app_id, + variable=tool_variable, + config=tool_config + ) + + # query external data tool + result = external_data_tool_factory.query( + inputs=inputs, + query=query + ) + + tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True)) + + return tool_key, result + @classmethod def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> 
str: if app.mode != 'completion': return query - + return inputs.get(app_model_config.dataset_query_variable, "") @classmethod diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 3be6ffaee37bb0..9dd211d36087df 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -290,6 +290,10 @@ def on_dataset_query_finish(self, resource: List): db.session.commit() self.retriever_resource = resource + def on_message_replace(self, text: str): + if text is not None: + self._pub_handler.pub_message_replace(text) + def message_end(self): self._pub_handler.pub_message_end(self.retriever_resource) @@ -342,6 +346,24 @@ def pub_text(self, text: str): self.pub_end() raise ConversationTaskStoppedException() + def pub_message_replace(self, text: str): + content = { + 'event': 'message_replace', + 'data': { + 'task_id': self._task_id, + 'message_id': str(self._message.id), + 'text': text, + 'mode': self._conversation.mode, + 'conversation_id': str(self._conversation.id) + } + } + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + def pub_chain(self, message_chain: MessageChain): if self._chain_pub: content = { @@ -443,3 +465,7 @@ def stop(cls, user: Union[Account | EndUser], task_id: str): class ConversationTaskStoppedException(Exception): pass + + +class ConversationTaskInterruptException(Exception): + pass diff --git a/api/core/extension/__init__.py b/api/core/extension/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py new file mode 100644 index 00000000000000..8ce7edabf23c47 --- /dev/null +++ b/api/core/extension/api_based_extension_requestor.py @@ -0,0 +1,62 @@ +import os + +import requests + +from models.api_based_extension import APIBasedExtensionPoint + + +class 
APIBasedExtensionRequestor: + timeout: (int, int) = (5, 60) + """timeout for request connect and read""" + + def __init__(self, api_endpoint: str, api_key: str) -> None: + self.api_endpoint = api_endpoint + self.api_key = api_key + + def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: + """ + Request the api. + + :param point: the api point + :param params: the request params + :return: the response json + """ + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(self.api_key) + } + + url = self.api_endpoint + + try: + # proxy support for security + proxies = None + if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"): + proxies = { + 'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"), + 'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"), + } + + response = requests.request( + method='POST', + url=url, + json={ + 'point': point.value, + 'params': params + }, + headers=headers, + timeout=self.timeout, + proxies=proxies + ) + except requests.exceptions.Timeout: + raise ValueError("request timeout") + except requests.exceptions.ConnectionError: + raise ValueError("request connection error") + + if response.status_code != 200: + raise ValueError("request error, status_code: {}, content: {}".format( + response.status_code, + response.text[:100] + )) + + return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py new file mode 100644 index 00000000000000..2e879578bf7357 --- /dev/null +++ b/api/core/extension/extensible.py @@ -0,0 +1,111 @@ +import enum +import importlib.util +import json +import logging +import os +from collections import OrderedDict +from typing import Any, Optional + +from pydantic import BaseModel + + +class ExtensionModule(enum.Enum): + MODERATION = 'moderation' + EXTERNAL_DATA_TOOL = 'external_data_tool' + + +class ModuleExtension(BaseModel): + extension_class: Any + name: str + label: 
class Extensible:
    """Base class for code-based extensions discovered from sibling subdirectories."""

    module: ExtensionModule

    name: str
    tenant_id: str
    config: Optional[dict] = None

    def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
        self.tenant_id = tenant_id
        self.config = config

    @classmethod
    def scan_extensions(cls):
        """
        Discover extension implementations in the subdirectories next to this
        class's module.

        Each subdirectory <name>/ must contain <name>.py defining a subclass of
        this class; non-builtin extensions must additionally ship a schema.json.
        A __builtin__ marker file stores the sort position of builtin extensions.

        :return: OrderedDict of extension name -> ModuleExtension, builtin
                 extensions first ordered by position, then the rest
        """
        extensions = {}

        # get the path of the current class's module file
        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
        current_dir_path = os.path.dirname(current_path)

        # traverse subdirectories
        for subdir_name in os.listdir(current_dir_path):
            if subdir_name.startswith('__'):
                continue

            subdir_path = os.path.join(current_dir_path, subdir_name)
            if not os.path.isdir(subdir_path):
                continue

            extension_name = subdir_name
            file_names = os.listdir(subdir_path)

            # builtin extensions get special treatment in the front-end page and
            # business logic; the __builtin__ file holds their sort position.
            # BUGFIX: membership in file_names already guarantees existence, so
            # the extra os.path.exists() checks were redundant and are removed.
            builtin = '__builtin__' in file_names
            position = None
            if builtin:
                with open(os.path.join(subdir_path, '__builtin__'), 'r') as f:
                    position = int(f.read().strip())

            if (extension_name + '.py') not in file_names:
                logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
                continue

            # dynamically load {subdir_name}.py and find the subclass of this class
            py_path = os.path.join(subdir_path, extension_name + '.py')
            spec = importlib.util.spec_from_file_location(extension_name, py_path)
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)

            extension_class = None
            for obj in vars(mod).values():
                if isinstance(obj, type) and issubclass(obj, cls) and obj is not cls:
                    extension_class = obj
                    break

            if not extension_class:
                logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
                continue

            # BUGFIX: json_data was initialized twice in the original
            json_data = {}
            if not builtin:
                if 'schema.json' not in file_names:
                    logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                    continue

                with open(os.path.join(subdir_path, 'schema.json'), 'r') as f:
                    json_data = json.load(f)

            extensions[extension_name] = ModuleExtension(
                extension_class=extension_class,
                name=extension_name,
                label=json_data.get('label'),
                form_schema=json_data.get('form_schema'),
                builtin=builtin,
                position=position
            )

        # builtin extensions (position is not None) sort first, in position order
        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
        return OrderedDict(sorted_items)
class Extension:
    """Registry of scanned code-based extensions, keyed by module then name."""

    __module_extensions: dict[str, dict[str, ModuleExtension]] = {}

    module_classes = {
        ExtensionModule.MODERATION: Moderation,
        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
    }

    def init(self):
        """Scan and cache the extensions of every registered module."""
        for module, module_class in self.module_classes.items():
            self.__module_extensions[module.value] = module_class.scan_extensions()

    def module_extensions(self, module: str) -> list[ModuleExtension]:
        """Return every extension registered under *module*, or raise if none."""
        found = self.__module_extensions.get(module)
        if not found:
            raise ValueError(f"Extension Module {module} not found")

        return list(found.values())

    def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
        """Return the extension of *module* named *extension_name*, or raise."""
        found = self.__module_extensions.get(module.value)
        if not found:
            raise ValueError(f"Extension Module {module} not found")

        extension = found.get(extension_name)
        if not extension:
            raise ValueError(f"Extension {extension_name} not found")

        return extension

    def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
        """Return the implementation class of the named extension."""
        return self.module_extension(module, extension_name).extension_class

    def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
        """Validate *config* against the extension's declared form schema."""
        form_schema = self.module_extension(module, extension_name).form_schema

        # TODO validate form_schema
class ApiExternalDataTool(ExternalDataTool):
    """
    The api external data tool.
    """

    name: str = "api"
    """the unique name of external data tool"""

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        # own validation logic
        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        # the referenced API-based extension must exist in this workspace
        api_based_extension = db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id
        ).first()

        if not api_based_extension:
            raise ValueError("api_based_extension_id is invalid")

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        :raises ValueError: if the extension is missing or the API call fails
        """
        # get params from config
        api_based_extension_id = self.config.get("api_based_extension_id")

        # get api_based_extension
        api_based_extension = db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == self.tenant_id,
            APIBasedExtension.id == api_based_extension_id
        ).first()

        if not api_based_extension:
            raise ValueError("[External data tool] API query failed, variable: {}, "
                             "error: api_based_extension_id is invalid"
                             .format(self.config.get('variable')))

        # decrypt api_key
        api_key = encrypter.decrypt_token(
            tenant_id=self.tenant_id,
            token=api_based_extension.api_key
        )

        try:
            requestor = APIBasedExtensionRequestor(
                api_endpoint=api_based_extension.api_endpoint,
                api_key=api_key
            )

            # BUGFIX: the request itself — the only part that can actually
            # fail — was previously *outside* this try block, so API errors
            # were never wrapped with the variable context
            response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
                'app_id': self.app_id,
                'tool_variable': self.variable,
                'inputs': inputs,
                'query': query
            })
        except Exception as e:
            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
                self.config.get('variable'),
                e
            )) from e

        if 'result' not in response_json:
            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
                             .format(self.config.get('variable')))

        return response_json['result']
+ """ + + module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL + + app_id: str + """the id of app""" + variable: str + """the tool variable name of app tool""" + + def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: + super().__init__(tenant_id, config) + self.app_id = app_id + self.variable = variable + + @classmethod + @abstractmethod + def validate_config(cls, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + raise NotImplementedError + + @abstractmethod + def query(self, inputs: dict, query: Optional[str] = None) -> str: + """ + Query the external data tool. + + :param inputs: user inputs + :param query: the query of chat app + :return: the tool query result + """ + raise NotImplementedError diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py new file mode 100644 index 00000000000000..979f243af65f61 --- /dev/null +++ b/api/core/external_data_tool/factory.py @@ -0,0 +1,40 @@ +from typing import Optional + +from core.extension.extensible import ExtensionModule +from extensions.ext_code_based_extension import code_based_extension + + +class ExternalDataToolFactory: + + def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: + extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) + self.__extension_instance = extension_class( + tenant_id=tenant_id, + app_id=app_id, + variable=variable, + config=config + ) + + @classmethod + def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. 
class ExternalDataToolFactory:
    """Instantiates and delegates to the external data tool registered under a name."""

    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
        tool_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        self.__extension_instance = tool_class(
            tenant_id=tenant_id,
            app_id=app_id,
            variable=variable,
            config=config
        )

    @classmethod
    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param name: the name of external data tool
        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        # schema validation first, then the tool's own validation logic
        code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
        tool_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        tool_class.validate_config(tenant_id, config)

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        return self.__extension_instance.query(inputs, query)
class ApiModeration(Moderation):
    """Moderation backed by a workspace's API-based extension endpoint."""

    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, False)

        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        if not cls._get_api_based_extension(tenant_id, api_based_extension_id):
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """Review user inputs via the configured API extension."""
        if self.config['inputs_config']['enabled']:
            params = ModerationInputParams(
                app_id=self.app_id,
                inputs=inputs,
                query=query
            )
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
            return ModerationInputsResult(**result)

        # inputs moderation disabled: report "not flagged"
        return ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """Review LLM output text via the configured API extension."""
        if self.config['outputs_config']['enabled']:
            params = ModerationOutputParams(
                app_id=self.app_id,
                text=text
            )
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
            return ModerationOutputsResult(**result)

        return ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        """POST *params* to the workspace's API-based extension and return the JSON body."""
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
        return requestor.request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """Look up the workspace's API-based extension row, or None if missing."""
        return db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id
        ).first()
class Moderation(Extensible, ABC):
    """
    The base class of moderation.
    """
    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user inputs, this method will be called to perform sensitive content review
        on the user inputs and return the processed results.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.
        When LLM outputs content, the front end will pass the output content (may be segmented)
        to this method for sensitive content review, and the output content will be shielded if the review fails.

        :param text: LLM output content
        :return:
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
        """
        Shared validation of the inputs_config / outputs_config sections.

        :param config: the form config data
        :param is_preset_response_required: whether preset_response must be present
        :raises ValueError: on any structural problem
        """
        # BUGFIX: first parameter renamed self -> cls; this is a @classmethod

        # inputs_config
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
            raise ValueError("inputs_config must be a dict")

        # outputs_config
        outputs_config = config.get("outputs_config")
        if not isinstance(outputs_config, dict):
            raise ValueError("outputs_config must be a dict")

        inputs_config_enabled = inputs_config.get("enabled")
        outputs_config_enabled = outputs_config.get("enabled")
        if not inputs_config_enabled and not outputs_config_enabled:
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        # preset_response
        if not is_preset_response_required:
            return

        if inputs_config_enabled:
            preset_response = inputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("inputs_config.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError("inputs_config.preset_response must be less than 100 characters")

        if outputs_config_enabled:
            preset_response = outputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("outputs_config.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError("outputs_config.preset_response must be less than 100 characters")


class ModerationException(Exception):
    pass
than 100 characters") + + +class ModerationException(Exception): + pass diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py new file mode 100644 index 00000000000000..96bf2ab54b41eb --- /dev/null +++ b/api/core/moderation/factory.py @@ -0,0 +1,48 @@ +from core.extension.extensible import ExtensionModule +from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult +from extensions.ext_code_based_extension import code_based_extension + + +class ModerationFactory: + __extension_instance: Moderation + + def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: + extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + self.__extension_instance = extension_class(app_id, tenant_id, config) + + @classmethod + def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + """ + Validate the incoming form config data. + + :param name: the name of extension + :param tenant_id: the id of workspace + :param config: the form config data + :return: + """ + code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) + extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) + extension_class.validate_config(tenant_id, config) + + def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + """ + Moderation for inputs. + After the user inputs, this method will be called to perform sensitive content review + on the user inputs and return the processed results. + + :param inputs: user inputs + :param query: query string (required in chat app) + :return: + """ + return self.__extension_instance.moderation_for_inputs(inputs, query) + + def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: + """ + Moderation for outputs. 
class KeywordsModeration(Moderation):
    """Moderation that flags content containing any configured keyword (case-insensitive)."""

    name: str = "keywords"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

        if not config.get("keywords"):
            raise ValueError("keywords is required")

        if len(config.get("keywords")) > 1000:
            raise ValueError("keywords length must be less than 1000")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """Flag the inputs if any keyword occurs in any input value or the query."""
        flagged = False
        preset_response = ""

        if self.config['inputs_config']['enabled']:
            preset_response = self.config['inputs_config']['preset_response']

            # BUGFIX: check a copy so the caller's inputs dict is not mutated
            values_to_check = dict(inputs)
            if query:
                values_to_check['query__'] = query

            flagged = self._is_violated(values_to_check, self._keywords_list())

        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """Flag the output text if any keyword occurs in it."""
        flagged = False
        preset_response = ""

        if self.config['outputs_config']['enabled']:
            flagged = self._is_violated({'text': text}, self._keywords_list())
            preset_response = self.config['outputs_config']['preset_response']

        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def _keywords_list(self) -> list:
        # BUGFIX: drop blank lines from the keywords config — an empty keyword
        # is a substring of every value and would flag all content
        return [keyword for keyword in self.config['keywords'].split('\n') if keyword.strip()]

    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
        # True if any value contains any keyword
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list, value):
        # BUGFIX: values may be non-strings (e.g. numbers from the input form);
        # coerce before lowering instead of crashing
        value_lower = str(value).lower()
        return any(keyword.lower() in value_lower for keyword in keywords_list)
class OpenAIModeration(Moderation):
    """Moderation backed by the OpenAI moderation model."""

    name: str = "openai_moderation"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """Review user inputs (and query) with the OpenAI moderation model."""
        flagged = False
        preset_response = ""

        if self.config['inputs_config']['enabled']:
            preset_response = self.config['inputs_config']['preset_response']

            # BUGFIX: check a copy so the caller's inputs dict is not mutated
            values_to_check = dict(inputs)
            if query:
                values_to_check['query__'] = query
            flagged = self._is_violated(values_to_check)

        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """Review LLM output text with the OpenAI moderation model."""
        flagged = False
        preset_response = ""

        if self.config['outputs_config']['enabled']:
            flagged = self._is_violated({'text': text})
            preset_response = self.config['outputs_config']['preset_response']

        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def _is_violated(self, inputs: dict):
        # BUGFIX: values may be non-strings; str.join raises TypeError on them,
        # so coerce each value first
        text = '\n'.join(str(value) for value in inputs.values())
        openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")
        # the model returns True when the text is acceptable; invert for "flagged"
        is_not_invalid = openai_moderation.run(text)
        return not is_not_invalid
"moderation") + is_not_invalid = openai_moderation.run(text) + return not is_not_invalid diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 2ba732ee3dd613..d057c160a2a4a5 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -11,7 +11,6 @@ from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule from core.conversation_message_task import ConversationMessageTask from core.model_providers.error import ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory @@ -125,52 +124,6 @@ def to_agent_executor(self, conversation_message_task: ConversationMessageTask, return chain - def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \ - -> Optional[SensitiveWordAvoidanceChain]: - """ - Convert app sensitive word avoidance config to chain - - :param model_instance: model instance - :param callbacks: callbacks for the chain - :param kwargs: - :return: - """ - sensitive_word_avoidance_rule = None - - if self.app_model_config.sensitive_word_avoidance_dict: - sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict - if sensitive_word_avoidance_config.get("enabled", False): - if sensitive_word_avoidance_config.get('type') == 'moderation': - sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule( - type=SensitiveWordAvoidanceRule.Type.MODERATION, - canned_response=sensitive_word_avoidance_config.get("canned_response") - if sensitive_word_avoidance_config.get("canned_response") - else 'Your content violates our usage policy. 
class HiddenAPIKey(fields.Raw):
    """Marshals an api_key as a masked preview: first 3 chars + '***' + last 3 chars."""

    def output(self, key, obj):
        api_key = obj.api_key
        # BUGFIX: for keys of 6 chars or fewer the [:3] and [-3:] slices
        # overlap and the "mask" could reveal the entire key — mask those fully
        if len(api_key) <= 6:
            return '***'
        return api_key[:3] + '***' + api_key[-3:]


api_based_extension_fields = {
    'id': fields.String,
    'name': fields.String,
    'api_endpoint': fields.String,
    'api_key': HiddenAPIKey,
    'created_at': TimestampField
}
fccfa5df306f5d..154b10ceb6f3d3 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -23,6 +23,7 @@ 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), + 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), 'model': fields.Raw(attribute='model_dict'), 'user_input_form': fields.Raw(attribute='user_input_form_list'), 'dataset_query_variable': fields.String, diff --git a/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py new file mode 100644 index 00000000000000..57b28e707f3b2b --- /dev/null +++ b/api/migrations/versions/968fff4c0ab9_add_api_based_extension.py @@ -0,0 +1,45 @@ +"""add_api_based_extension + +Revision ID: 968fff4c0ab9 +Revises: b3a09c049e8e +Create Date: 2023-10-27 13:05:58.901858 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '968fff4c0ab9' +down_revision = 'b3a09c049e8e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + + op.create_table('api_based_extensions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('api_endpoint', sa.String(length=255), nullable=False), + sa.Column('api_key', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey') + ) + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.create_index('api_based_extension_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('api_based_extensions', schema=None) as batch_op: + batch_op.drop_index('api_based_extension_tenant_idx') + + op.drop_table('api_based_extensions') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py new file mode 100644 index 00000000000000..9b452f75eed6d4 --- /dev/null +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -0,0 +1,32 @@ +"""add external_data_tools in app model config + +Revision ID: a9836e3baeee +Revises: 968fff4c0ab9 +Create Date: 2023-11-02 04:04:57.609485 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'a9836e3baeee' +down_revision = '968fff4c0ab9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('external_data_tools') + + # ### end Alembic commands ### diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py new file mode 100644 index 00000000000000..e34cfb8f7b3371 --- /dev/null +++ b/api/models/api_based_extension.py @@ -0,0 +1,27 @@ +import enum + +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db + + +class APIBasedExtensionPoint(enum.Enum): + APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' + PING = 'ping' + APP_MODERATION_INPUT = 'app.moderation.input' + APP_MODERATION_OUTPUT = 'app.moderation.output' + + +class APIBasedExtension(db.Model): + __tablename__ = 'api_based_extensions' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), + db.Index('api_based_extension_tenant_idx', 'tenant_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + api_endpoint = db.Column(db.String(255), nullable=False) + api_key = db.Column(db.Text, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/model.py b/api/models/model.py index d3f5c8135f1f99..cefc275e35ccb9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -97,6 +97,7 @@ class AppModelConfig(db.Model): chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) + external_data_tools = db.Column(db.Text) @property def app(self): @@ -133,7 +134,12 @@ 
def more_like_this_dict(self) -> dict: @property def sensitive_word_avoidance_dict(self) -> dict: return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ - else {"enabled": False, "words": [], "canned_response": []} + else {"enabled": False, "type": "", "configs": []} + + @property + def external_data_tools_list(self) -> list[dict]: + return json.loads(self.external_data_tools) if self.external_data_tools \ + else [] @property def user_input_form_list(self) -> dict: @@ -167,6 +173,7 @@ def to_dict(self) -> dict: "retriever_resource": self.retriever_resource_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, + "external_data_tools": self.external_data_tools_list, "model": self.model_dict, "user_input_form": self.user_input_form_list, "dataset_query_variable": self.dataset_query_variable, @@ -190,6 +197,7 @@ def from_model_config_dict(self, model_config: dict): self.more_like_this = json.dumps(model_config['more_like_this']) self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ if model_config.get('sensitive_word_avoidance') else None + self.external_data_tools = json.dumps(model_config['external_data_tools']) self.model = json.dumps(model_config['model']) self.user_input_form = json.dumps(model_config['user_input_form']) self.dataset_query_variable = model_config.get('dataset_query_variable') @@ -219,6 +227,7 @@ def copy(self): speech_to_text=self.speech_to_text, more_like_this=self.more_like_this, sensitive_word_avoidance=self.sensitive_word_avoidance, + external_data_tools=self.external_data_tools, model=self.model, user_input_form=self.user_input_form, dataset_query_variable=self.dataset_query_variable, @@ -332,41 +341,16 @@ def model_config(self): override_model_configs = json.loads(self.override_model_configs) if 'model' in override_model_configs: - model_config['model'] = override_model_configs['model'] - model_config['pre_prompt'] = 
override_model_configs['pre_prompt'] - model_config['agent_mode'] = override_model_configs['agent_mode'] - model_config['opening_statement'] = override_model_configs['opening_statement'] - model_config['suggested_questions'] = override_model_configs['suggested_questions'] - model_config['suggested_questions_after_answer'] = override_model_configs[ - 'suggested_questions_after_answer'] \ - if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False} - model_config['speech_to_text'] = override_model_configs[ - 'speech_to_text'] \ - if 'speech_to_text' in override_model_configs else {"enabled": False} - model_config['more_like_this'] = override_model_configs['more_like_this'] \ - if 'more_like_this' in override_model_configs else {"enabled": False} - model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \ - if 'sensitive_word_avoidance' in override_model_configs \ - else {"enabled": False, "words": [], "canned_response": []} - model_config['user_input_form'] = override_model_configs['user_input_form'] + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(override_model_configs) + model_config = app_model_config.to_dict() else: model_config['configs'] = override_model_configs else: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() - model_config['configs'] = app_model_config.configs - model_config['model'] = app_model_config.model_dict - model_config['pre_prompt'] = app_model_config.pre_prompt - model_config['agent_mode'] = app_model_config.agent_mode_dict - model_config['opening_statement'] = app_model_config.opening_statement - model_config['suggested_questions'] = app_model_config.suggested_questions_list - model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict - model_config['speech_to_text'] = app_model_config.speech_to_text_dict - 
model_config['retriever_resource'] = app_model_config.retriever_resource_dict - model_config['more_like_this'] = app_model_config.more_like_this_dict - model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict - model_config['user_input_form'] = app_model_config.user_input_form_list + model_config = app_model_config.to_dict() model_config['model_id'] = self.model_id model_config['provider'] = self.model_provider diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py new file mode 100644 index 00000000000000..d4e7d5be3d03f2 --- /dev/null +++ b/api/services/api_based_extension_service.py @@ -0,0 +1,98 @@ +from extensions.ext_database import db +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from core.helper.encrypter import encrypt_token, decrypt_token +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor + + +class APIBasedExtensionService: + + @staticmethod + def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: + extension_list = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=tenant_id) \ + .order_by(APIBasedExtension.created_at.desc()) \ + .all() + + for extension in extension_list: + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) + + return extension_list + + @classmethod + def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension: + cls._validation(extension_data) + + extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) + + db.session.add(extension_data) + db.session.commit() + return extension_data + + @staticmethod + def delete(extension_data: APIBasedExtension) -> None: + db.session.delete(extension_data) + db.session.commit() + + @staticmethod + def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + extension = db.session.query(APIBasedExtension) \ + 
.filter_by(tenant_id=tenant_id) \ + .filter_by(id=api_based_extension_id) \ + .first() + + if not extension: + raise ValueError("API based extension is not found") + + extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) + + return extension + + @classmethod + def _validation(cls, extension_data: APIBasedExtension) -> None: + # name + if not extension_data.name: + raise ValueError("name must not be empty") + + if not extension_data.id: + # case one: check new data, name must be unique + is_name_existed = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=extension_data.tenant_id) \ + .filter_by(name=extension_data.name) \ + .first() + + if is_name_existed: + raise ValueError("name must be unique, it is already existed") + else: + # case two: check existing data, name must be unique + is_name_existed = db.session.query(APIBasedExtension) \ + .filter_by(tenant_id=extension_data.tenant_id) \ + .filter_by(name=extension_data.name) \ + .filter(APIBasedExtension.id != extension_data.id) \ + .first() + + if is_name_existed: + raise ValueError("name must be unique, it is already existed") + + # api_endpoint + if not extension_data.api_endpoint: + raise ValueError("api_endpoint must not be empty") + + # api_key + if not extension_data.api_key: + raise ValueError("api_key must not be empty") + + if len(extension_data.api_key) < 5: + raise ValueError("api_key must be at least 5 characters") + + # check endpoint + cls._ping_connection(extension_data) + + @staticmethod + def _ping_connection(extension_data: APIBasedExtension) -> None: + try: + client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) + resp = client.request(point=APIBasedExtensionPoint.PING, params={}) + if resp.get('result') != 'pong': + raise ValueError(resp) + except Exception as e: + raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 
79c1ed0ad6f663..b9b4e84735f29a 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,6 +1,8 @@ import re import uuid +from core.external_data_tool.factory import ExternalDataToolFactory +from core.moderation.factory import ModerationFactory from core.prompt.prompt_transform import AppMode from core.agent.agent_executor import PlanningStrategy from core.model_providers.model_provider_factory import ModelProviderFactory @@ -13,8 +15,8 @@ class AppModelConfigService: - @staticmethod - def is_dataset_exists(account: Account, dataset_id: str) -> bool: + @classmethod + def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool: # verify if the dataset ID exists dataset = DatasetService.get_dataset(dataset_id) @@ -26,8 +28,8 @@ def is_dataset_exists(account: Account, dataset_id: str) -> bool: return True - @staticmethod - def validate_model_completion_params(cp: dict, model_name: str) -> dict: + @classmethod + def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: # 6. 
model.completion_params if not isinstance(cp, dict): raise ValueError("model.completion_params must be of object type") @@ -57,7 +59,7 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") - + if len(cp["stop"]) > 4: raise ValueError("stop sequences must be less than 4") @@ -73,8 +75,8 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: return filtered_cp - @staticmethod - def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict: + @classmethod + def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict: # opening_statement if 'opening_statement' not in config or not config["opening_statement"]: config["opening_statement"] = "" @@ -153,33 +155,6 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: if not isinstance(config["more_like_this"]["enabled"], bool): raise ValueError("enabled in more_like_this must be of boolean type") - # sensitive_word_avoidance - if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: - config["sensitive_word_avoidance"] = { - "enabled": False - } - - if not isinstance(config["sensitive_word_avoidance"], dict): - raise ValueError("sensitive_word_avoidance must be of dict type") - - if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: - config["sensitive_word_avoidance"]["enabled"] = False - - if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool): - raise ValueError("enabled in sensitive_word_avoidance must be of boolean type") - - if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]: - config["sensitive_word_avoidance"]["words"] = "" - - if not isinstance(config["sensitive_word_avoidance"]["words"], str): - 
raise ValueError("words in sensitive_word_avoidance must be of string type") - - if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]: - config["sensitive_word_avoidance"]["canned_response"] = "" - - if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str): - raise ValueError("canned_response in sensitive_word_avoidance must be of string type") - # model if 'model' not in config: raise ValueError("model is required") @@ -204,7 +179,7 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: model_ids = [m['id'] for m in model_list] if config["model"]["name"] not in model_ids: raise ValueError("model.name must be in the specified model list") - + # model.mode if 'mode' not in config['model'] or not config['model']["mode"]: config['model']["mode"] = "" @@ -213,7 +188,7 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: if 'completion_params' not in config["model"]: raise ValueError("model.completion_params is required") - config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params( + config["model"]["completion_params"] = cls.validate_model_completion_params( config["model"]["completion_params"], config["model"]["name"] ) @@ -330,14 +305,20 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: except ValueError: raise ValueError("id in dataset must be of UUID type") - if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]): + if not cls.is_dataset_exists(account, tool_item["id"]): raise ValueError("Dataset ID does not exist, please check your permission.") - + # dataset_query_variable - AppModelConfigService.is_dataset_query_variable_valid(config, mode) + cls.is_dataset_query_variable_valid(config, mode) # advanced prompt validation - AppModelConfigService.is_advanced_prompt_valid(config, mode) + cls.is_advanced_prompt_valid(config, mode) 
+ + # external data tools validation + cls.is_external_data_tools_valid(tenant_id, config) + + # moderation validation + cls.is_moderation_valid(tenant_id, config) # Filter out extra parameters filtered_config = { @@ -348,6 +329,7 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: "retriever_resource": config["retriever_resource"], "more_like_this": config["more_like_this"], "sensitive_word_avoidance": config["sensitive_word_avoidance"], + "external_data_tools": config["external_data_tools"], "model": { "provider": config["model"]["provider"], "name": config["model"]["name"], @@ -365,32 +347,86 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: } return filtered_config - - @staticmethod - def is_dataset_query_variable_valid(config: dict, mode: str) -> None: + + @classmethod + def is_moderation_valid(cls, tenant_id: str, config: dict): + if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: + config["sensitive_word_avoidance"] = { + "enabled": False + } + + if not isinstance(config["sensitive_word_avoidance"], dict): + raise ValueError("sensitive_word_avoidance must be of dict type") + + if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: + config["sensitive_word_avoidance"]["enabled"] = False + + if not config["sensitive_word_avoidance"]["enabled"]: + return + + if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: + raise ValueError("sensitive_word_avoidance.type is required") + + type = config["sensitive_word_avoidance"]["type"] + config = config["sensitive_word_avoidance"]["config"] + + ModerationFactory.validate_config( + name=type, + tenant_id=tenant_id, + config=config + ) + + @classmethod + def is_external_data_tools_valid(cls, tenant_id: str, config: dict): + if 'external_data_tools' not in config or not config["external_data_tools"]: + 
config["external_data_tools"] = [] + + if not isinstance(config["external_data_tools"], list): + raise ValueError("external_data_tools must be of list type") + + for tool in config["external_data_tools"]: + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + + if not tool["enabled"]: + continue + + if "type" not in tool or not tool["type"]: + raise ValueError("external_data_tools[].type is required") + + type = tool["type"] + config = tool["config"] + + ExternalDataToolFactory.validate_config( + name=type, + tenant_id=tenant_id, + config=config + ) + + @classmethod + def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: # Only check when mode is completion if mode != 'completion': return - + agent_mode = config.get("agent_mode", {}) tools = agent_mode.get("tools", []) dataset_exists = "dataset" in str(tools) - + dataset_query_variable = config.get("dataset_query_variable") if dataset_exists and not dataset_query_variable: raise ValueError("Dataset query variable is required when dataset is exist") - - @staticmethod - def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: + @classmethod + def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type if 'prompt_type' not in config or not config["prompt_type"]: config["prompt_type"] = "simple" if config['prompt_type'] not in ['simple', 'advanced']: raise ValueError("prompt_type must be in ['simple', 'advanced']") - + # chat_prompt_config if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: config["chat_prompt_config"] = {} @@ -404,7 +440,7 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - + # dataset_configs if 'dataset_configs' not in config or not config["dataset_configs"]: config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} @@ -415,10 
+451,10 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: if config['prompt_type'] == 'advanced': if not config['chat_prompt_config'] and not config['completion_prompt_config']: raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") - + if config['model']["mode"] not in ['chat', 'completion']: raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - + if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value: user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] @@ -429,9 +465,8 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: if not assistant_prefix: config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' - if config['model']["mode"] == ModelMode.CHAT.value: prompt_list = config['chat_prompt_config']['prompt'] if len(prompt_list) > 10: - raise ValueError("prompt messages must be less than 10") \ No newline at end of file + raise ValueError("prompt messages must be less than 10") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py new file mode 100644 index 00000000000000..7b0d50a835dbf8 --- /dev/null +++ b/api/services/code_based_extension_service.py @@ -0,0 +1,13 @@ +from extensions.ext_code_based_extension import code_based_extension + + +class CodeBasedExtensionService: + + @staticmethod + def get_code_based_extension(module: str) -> list[dict]: + module_extensions = code_based_extension.module_extensions(module) + return [{ + 'name': module_extension.name, + 'label': module_extension.label, + 'form_schema': module_extension.form_schema + } for module_extension in module_extensions if not module_extension.builtin] diff --git 
a/api/services/completion_service.py b/api/services/completion_service.py index 54b150d155fb82..81e35c85939770 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -10,7 +10,8 @@ from sqlalchemy import and_ from core.completion import Completion -from core.conversation_message_task import PubHandler, ConversationTaskStoppedException +from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \ + ConversationTaskInterruptException from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ LLMRateLimitError, \ LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError @@ -28,9 +29,9 @@ class CompletionService: @classmethod - def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, + def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, from_source: str, streaming: bool = True, - is_model_config_override: bool = False) -> Union[dict | Generator]: + is_model_config_override: bool = False) -> Union[dict, Generator]: # is streaming mode inputs = args['inputs'] query = args['query'] @@ -199,9 +200,9 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m is_override=is_model_config_override, retriever_from=retriever_from ) - except ConversationTaskStoppedException: + except (ConversationTaskInterruptException, ConversationTaskStoppedException): pass - except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError) as e: PubHandler.pub_error(user, generate_task_id, e) @@ -234,7 +235,7 @@ def close_pubsub(): PubHandler.stop(user, generate_task_id) try: pubsub.close() - except: + except Exception: pass countdown_thread = 
threading.Thread(target=close_pubsub) @@ -243,9 +244,9 @@ def close_pubsub(): return countdown_thread @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], + def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], message_id: str, streaming: bool = True, - retriever_from: str = 'dev') -> Union[dict | Generator]: + retriever_from: str = 'dev') -> Union[dict, Generator]: if not user: raise ValueError('user cannot be None') @@ -341,7 +342,7 @@ def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig) return filtered_inputs @classmethod - def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]: + def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]: generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') if not streaming: try: @@ -386,6 +387,8 @@ def generate() -> Generator: break if event == 'message': yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" + elif event == 'message_replace': + yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n" elif event == 'chain': yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" elif event == 'agent_thought': @@ -427,6 +430,21 @@ def get_message_response_data(cls, data: dict): return response_data + @classmethod + def get_message_replace_response_data(cls, data: dict): + response_data = { + 'event': 'message_replace', + 'task_id': data.get('task_id'), + 'id': data.get('message_id'), + 'answer': data.get('text'), + 'created_at': int(time.time()) + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + @classmethod def get_blocking_message_response_data(cls, data: dict): message = data.get('message') @@ -508,6 +526,7 @@ def handle_error(cls, result: dict): # handle errors 
llm_errors = { + 'ValueError': LLMBadRequestError, 'LLMBadRequestError': LLMBadRequestError, 'LLMAPIConnectionError': LLMAPIConnectionError, 'LLMAPIUnavailableError': LLMAPIUnavailableError, diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py new file mode 100644 index 00000000000000..d0be7bca80aeb1 --- /dev/null +++ b/api/services/moderation_service.py @@ -0,0 +1,20 @@ +from models.model import AppModelConfig, App +from core.moderation.factory import ModerationFactory, ModerationOutputsResult +from extensions.ext_database import db + + +class ModerationService: + + def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: + app_model_config: AppModelConfig = None + + app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + + if not app_model_config: + raise ValueError("app model config not found") + + name = app_model_config.sensitive_word_avoidance_dict['type'] + config = app_model_config.sensitive_word_avoidance_dict['config'] + + moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) + return moderation.moderation_for_outputs(text)