From 54ff03c35d391aae9d0c1ad7bae874460c3f48b0 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Wed, 27 Sep 2023 15:24:54 +0800 Subject: [PATCH 1/3] fix: dataset query error. (#1244) --- api/models/model.py | 2 +- api/services/app_model_config_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index 5fb2abe71c1c08..dd7f09bf78a1db 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -172,7 +172,7 @@ def from_model_config_dict(self, model_config: dict): if model_config.get('sensitive_word_avoidance') else None self.model = json.dumps(model_config['model']) self.user_input_form = json.dumps(model_config['user_input_form']) - self.dataset_query_variable = model_config['dataset_query_variable'] + self.dataset_query_variable = model_config.get('dataset_query_variable') self.pre_prompt = model_config['pre_prompt'] self.agent_mode = json.dumps(model_config['agent_mode']) self.retriever_resource = json.dumps(model_config['retriever_resource']) \ diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index f3334ca9b99aa3..916a1078e5baf7 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -354,7 +354,7 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: "completion_params": config["model"]["completion_params"] }, "user_input_form": config["user_input_form"], - "dataset_query_variable": config["dataset_query_variable"], + "dataset_query_variable": config.get('dataset_query_variable'), "pre_prompt": config["pre_prompt"], "agent_mode": config["agent_mode"] } From 46154c67056b86450a84853694e1dde64dfc8601 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:06:32 +0800 Subject: [PATCH 2/3] Feat/dataset service api (#1245) Co-authored-by: jyong Co-authored-by: StyleZhang --- api/controllers/console/apikey.py | 1 + api/controllers/console/app/app.py | 105 +---- api/controllers/console/app/conversation.py | 157 +------ api/controllers/console/app/message.py | 39 +- api/controllers/console/app/site.py | 17 +- .../console/datasets/data_source.py | 54 +-- api/controllers/console/datasets/datasets.py | 164 +++++--- .../console/datasets/datasets_document.py | 101 +---- .../console/datasets/datasets_segments.py | 33 +- api/controllers/console/datasets/file.py | 98 +---- .../console/datasets/hit_testing.py | 42 +- .../console/explore/conversation.py | 18 +- .../console/explore/installed_app.py | 23 +- api/controllers/console/explore/message.py | 40 +- .../console/universal_chat/conversation.py | 22 +- api/controllers/service_api/__init__.py | 2 +- .../service_api/app/conversation.py | 20 +- .../service_api/dataset/dataset.py | 84 ++++ .../service_api/dataset/document.py | 390 +++++++++++++++--- api/controllers/service_api/dataset/error.py | 69 +++- .../service_api/dataset/segment.py | 59 +++ api/controllers/service_api/wraps.py | 29 +- api/controllers/web/conversation.py | 18 +- api/core/data_loader/loader/notion.py | 1 + .../keyword_table_index.py | 17 + api/fields/__init__.py | 0 api/fields/app_fields.py | 138 +++++++ api/fields/conversation_fields.py | 182 ++++++++ api/fields/data_source_fields.py | 65 +++ api/fields/dataset_fields.py | 43 ++ api/fields/document_fields.py | 76 ++++ api/fields/file_fields.py | 18 + api/fields/hit_testing_fields.py | 41 ++ api/fields/installed_app_fields.py | 25 ++ api/fields/message_fields.py | 43 ++ api/fields/segment_fields.py | 32 ++ ...2e9819ca5b28_add_tenant_id_in_api_token.py | 36 ++ api/models/model.py | 5 +- api/services/dataset_service.py | 68 ++- api/services/errors/__init__.py | 2 +- api/services/errors/file.py | 8 + api/services/file_service.py | 123 ++++++ api/services/vector_service.py | 26 ++ 43 files changed, 1632 insertions(+), 902 deletions(-) create mode 100644 api/controllers/service_api/dataset/dataset.py create mode 100644 api/controllers/service_api/dataset/segment.py create mode 100644 api/fields/__init__.py create mode 100644 api/fields/app_fields.py create mode 100644 api/fields/conversation_fields.py create mode 100644 api/fields/data_source_fields.py create mode 100644 api/fields/dataset_fields.py create mode 100644 api/fields/document_fields.py create mode 100644 api/fields/file_fields.py create mode 100644 api/fields/hit_testing_fields.py create mode 100644 api/fields/installed_app_fields.py create mode 100644 api/fields/message_fields.py create mode 100644 api/fields/segment_fields.py create mode 100644 api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py create mode 100644 api/services/file_service.py diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index c63a21aaeae6c0..211ada6267c74e 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -81,6 +81,7 @@ def post(self, resource_id): key = ApiToken.generate_api_key(self.token_prefix, 24) api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) + api_token.tenant_id = current_user.current_tenant_id api_token.token = key api_token.type = self.resource_type db.session.add(api_token) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 7c88a14b77fdc6..0acbc3ae47d947 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -19,41 +19,13 @@ from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.models.entity.model_params import ModelType from events.app_event import app_was_created, app_was_deleted +from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \ + app_detail_fields_with_site from libs.helper import TimestampField from extensions.ext_database import db from models.model import App, AppModelConfig, Site from services.app_model_config_service import AppModelConfigService -model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), - 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'dataset_query_variable': fields.String, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), -} - -app_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), - 'created_at': TimestampField -} - def _get_app(app_id, tenant_id): app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() @@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id): class AppListApi(Resource): - prompt_config_fields = { - 'prompt_template': fields.String, - } - - model_config_partial_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, - } - - app_partial_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), - 'created_at': TimestampField - } - - app_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') - } @setup_required @login_required @@ -238,18 +181,6 @@ def post(self): class AppTemplateApi(Resource): - template_fields = { - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'mode': fields.String, - 'model_config': fields.Nested(model_config_fields), - } - - template_list_fields = { - 'data': fields.List(fields.Nested(template_fields)), - } @setup_required @login_required @@ -268,38 +199,6 @@ def get(self): class AppApi(Resource): - site_fields = { - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'app_base_url': fields.String, - } - - app_detail_fields_with_site = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_demo': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), - 'site': fields.Nested(site_fields), - 'api_base_url': fields.String, - 'created_at': TimestampField - } @setup_required @login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5b24bb2882a00a..bbc95c61f750db 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -13,107 +13,14 @@ from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \ + conversation_message_detail_fields, conversation_with_summary_pagination_fields from libs.helper import TimestampField, datetime_string, uuid_value from extensions.ext_database import db from models.model import Message, MessageAnnotation, Conversation -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - -feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(account_fields, allow_null=True), -} - -annotation_fields = { - 'content': fields.String, - 'account': fields.Nested(account_fields, allow_null=True), - 'created_at': TimestampField -} - -message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String, - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'created_at': TimestampField -} - -feedback_stat_fields = { - 'like': fields.Integer, - 'dislike': fields.Integer -} - -model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'model': fields.Raw, - 'user_input_form': fields.Raw, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw, -} - class CompletionConversationApi(Resource): - class MessageTextField(fields.Raw): - def format(self, value): - return value[0]['text'] if value else '' - - simple_configs_fields = { - 'prompt_template': fields.String, - } - - simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, - } - - simple_message_detail_fields = { - 'inputs': fields.Raw, - 'query': fields.String, - 'message': MessageTextField, - 'answer': fields.String, - } - - conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String(), - 'from_account_id': fields.String, - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'model_config': fields.Nested(simple_model_config_fields), - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields), - 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') - } - - conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), attribute='items') - } @setup_required @login_required @@ -191,21 +98,11 @@ def get(self, app_id): class CompletionConversationDetailApi(Resource): - conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Nested(model_config_fields), - 'message': fields.Nested(message_detail_fields, attribute='first_message'), - } @setup_required @login_required @account_initialization_required - @marshal_with(conversation_detail_fields) + @marshal_with(conversation_message_detail_fields) def get(self, app_id, conversation_id): app_id = str(app_id) conversation_id = str(conversation_id) @@ -234,44 +131,11 @@ def delete(self, app_id, conversation_id): class ChatConversationApi(Resource): - simple_configs_fields = { - 'prompt_template': fields.String, - } - - simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, - } - - conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String, - 'from_account_id': fields.String, - 'summary': fields.String(attribute='summary_or_query'), - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(simple_model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) - } - - conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), attribute='items') - } @setup_required @login_required @account_initialization_required - @marshal_with(conversation_pagination_fields) + @marshal_with(conversation_with_summary_pagination_fields) def get(self, app_id): app_id = str(app_id) @@ -356,19 +220,6 @@ def get(self, app_id): class ChatConversationDetailApi(Resource): - conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) - } @setup_required @login_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 9745e2f41a2e03..d634692ed060d9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -17,6 +17,7 @@ from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError from core.login.login import login_required +from fields.conversation_fields import message_detail_fields from libs.helper import uuid_value, TimestampField from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db @@ -27,44 +28,6 @@ from services.errors.message import MessageNotExistsError from services.message_service import MessageService -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} - -feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(account_fields, allow_null=True), -} - -annotation_fields = { - 'content': fields.String, - 'account': fields.Nested(account_fields, allow_null=True), - 'created_at': TimestampField -} - -message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String, - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'created_at': TimestampField -} - class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 128df38110c210..a796edf8d7f811 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -8,26 +8,11 @@ from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from fields.app_fields import app_site_fields from libs.helper import supported_language from extensions.ext_database import db from models.model import Site -app_site_fields = { - 'app_id': fields.String, - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean -} - def parse_app_site_args(): parser = reqparse.RequestParser() diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 87aa26304079b8..532f47af75e66c 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -14,6 +14,7 @@ from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner from extensions.ext_database import db +from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields from libs.helper import TimestampField from models.dataset import Document from models.source import DataSourceBinding @@ -24,37 +25,6 @@ class DataSourceApi(Resource): - integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String - } - integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'parent_id': fields.String, - 'type': fields.String - } - integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)), - 'total': fields.Integer - } - integrate_fields = { - 'id': fields.String, - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'disabled': fields.Boolean, - 'link': fields.String, - 'source_info': fields.Nested(integrate_workspace_fields) - } - integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), - } @setup_required @login_required @@ -131,28 +101,6 @@ def patch(self, binding_id, action): class DataSourceNotionListApi(Resource): - integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String - } - integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'is_bound': fields.Boolean, - 'parent_id': fields.String, - 'type': fields.String - } - integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)) - } - integrate_notion_info_list_fields = { - 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), - } @setup_required @login_required diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 6c73cb255ac277..dfaf5ff2d1474b 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,6 +1,9 @@ # -*- coding:utf-8 -*- -from flask import request +import flask_restful +from flask import request, current_app from flask_login import current_user + +from controllers.console.apikey import api_key_list, api_key_fields from core.login.login import login_required from flask_restful import Resource, reqparse, fields, marshal, marshal_with from werkzeug.exceptions import NotFound, Forbidden @@ -12,45 +15,16 @@ from controllers.console.wraps import account_initialization_required from core.indexing_runner import IndexingRunner from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError -from core.model_providers.model_factory import ModelFactory from core.model_providers.models.entity.model_params import ModelType -from libs.helper import TimestampField +from fields.app_fields import related_app_list +from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields +from fields.document_fields import document_status_fields from extensions.ext_database import db from models.dataset import DocumentSegment, Document -from models.model import UploadFile +from models.model import UploadFile, ApiToken from services.dataset_service import DatasetService, DocumentService from services.provider_service import ProviderService -dataset_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'provider': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'app_count': fields.Integer, - 'document_count': fields.Integer, - 'word_count': fields.Integer, - 'created_by': fields.String, - 'created_at': TimestampField, - 'updated_by': fields.String, - 'updated_at': TimestampField, - 'embedding_model': fields.String, - 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean -} - -dataset_query_detail_fields = { - "id": fields.String, - "content": fields.String, - "source": fields.String, - "source_app_id": fields.String, - "created_by_role": fields.String, - "created_by": fields.String, - "created_at": TimestampField -} - def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: @@ -82,7 +56,8 @@ def get(self): # check embedding setting provider_service = ProviderService() - valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) + valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, + ModelType.EMBEDDINGS.value) # if len(valid_model_list) == 0: # raise ProviderNotInitializeError( # f"No Embedding Model available. Please configure a valid provider " @@ -157,7 +132,8 @@ def get(self, dataset_id): # check embedding setting provider_service = ProviderService() # get valid model list - valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) + valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, + ModelType.EMBEDDINGS.value) model_names = [] for valid_model in valid_model_list: model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") @@ -271,7 +247,8 @@ def post(self): parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, + location='json') args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) @@ -320,18 +297,6 @@ def post(self): class DatasetRelatedAppListApi(Resource): - app_detail_kernel_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - } - - related_app_list = { - 'data': fields.List(fields.Nested(app_detail_kernel_fields)), - 'total': fields.Integer, - } @setup_required @login_required @@ -363,24 +328,6 @@ def get(self, dataset_id): class DatasetIndexingStatusApi(Resource): - document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, - } - - document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) - } @setup_required @login_required @@ -400,16 +347,97 @@ def get(self, dataset_id): DocumentSegment.status != 're_segment').count() document.completed_segments = completed_segments document.total_segments = total_segments - documents_status.append(marshal(document, self.document_status_fields)) + documents_status.append(marshal(document, document_status_fields)) data = { 'data': documents_status } return data +class DatasetApiKeyApi(Resource): + max_keys = 10 + token_prefix = 'dataset-' + resource_type = 'dataset' + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_key_list) + def get(self): + keys = db.session.query(ApiToken). \ + filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ + all() + return {"items": keys} + + @setup_required + @login_required + @account_initialization_required + @marshal_with(api_key_fields) + def post(self): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + current_key_count = db.session.query(ApiToken). \ + filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ + count() + + if current_key_count >= self.max_keys: + flask_restful.abort( + 400, + message=f"Cannot create more than {self.max_keys} API keys for this resource type.", + code='max_keys_exceeded' + ) + + key = ApiToken.generate_api_key(self.token_prefix, 24) + api_token = ApiToken() + api_token.tenant_id = current_user.current_tenant_id + api_token.token = key + api_token.type = self.resource_type + db.session.add(api_token) + db.session.commit() + return api_token, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, api_key_id): + api_key_id = str(api_key_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + key = db.session.query(ApiToken). \ + filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, + ApiToken.id == api_key_id). \ + first() + + if key is None: + flask_restful.abort(404, message='API key not found') + + db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.commit() + + return {'result': 'success'}, 204 + + +class DatasetApiBaseUrlApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + return { + 'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] + else request.host_url.rstrip('/')) + '/v1' + } + + api.add_resource(DatasetListApi, '/datasets') api.add_resource(DatasetApi, '/datasets/') api.add_resource(DatasetQueryApi, '/datasets//queries') api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') api.add_resource(DatasetIndexingStatusApi, '/datasets//indexing-status') +api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') +api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 9b212f2e90c6d4..5cdc79efd5b927 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -23,6 +23,8 @@ LLMBadRequestError from core.model_providers.model_factory import ModelFactory from extensions.ext_redis import redis_client +from fields.document_fields import document_with_segments_fields, document_fields, \ + dataset_and_document_fields, document_status_fields from libs.helper import TimestampField from extensions.ext_database import db from models.dataset import DatasetProcessRule, Dataset @@ -32,64 +34,6 @@ from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task -dataset_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, -} - -document_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'doc_form': fields.String, -} - -document_with_segments_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer -} - class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: @@ -303,11 +247,6 @@ def post(self, dataset_id): class DatasetInitApi(Resource): - dataset_and_document_fields = { - 'dataset': fields.Nested(dataset_fields), - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } @setup_required @login_required @@ -504,24 +443,6 @@ def get(self, dataset_id, batch): class DocumentBatchIndexingStatusApi(DocumentResource): - document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, - } - - document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) - } @setup_required @login_required @@ -541,7 +462,7 @@ def get(self, dataset_id, batch): document.total_segments = total_segments if document.is_paused: document.indexing_status = 'paused' - documents_status.append(marshal(document, self.document_status_fields)) + documents_status.append(marshal(document, document_status_fields)) data = { 'data': documents_status } @@ -549,20 +470,6 @@ def get(self, dataset_id, batch): class DocumentIndexingStatusApi(DocumentResource): - document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, - } @setup_required @login_required @@ -586,7 +493,7 @@ def get(self, dataset_id, document_id): document.total_segments = total_segments if document.is_paused: document.indexing_status = 'paused' - return marshal(document, self.document_status_fields) + return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 7dac492d1d36d7..8c164e2aa7dfdf 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,7 +3,7 @@ from datetime import datetime from flask import request from flask_login import current_user -from flask_restful import Resource, reqparse, fields, marshal +from flask_restful import Resource, reqparse, marshal from werkzeug.exceptions import NotFound, Forbidden import services @@ -17,6 +17,7 @@ from core.login.login import login_required from extensions.ext_database import db from extensions.ext_redis import redis_client +from fields.segment_fields import segment_fields from models.dataset import DocumentSegment from libs.helper import TimestampField @@ -26,36 +27,6 @@ from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task import pandas as pd -segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField -} - -segment_list_response = { - 'data': fields.List(fields.Nested(segment_fields)), - 'has_more': fields.Boolean, - 'limit': fields.Integer -} - class DatasetDocumentSegmentListApi(Resource): @setup_required diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index c68083df6a1245..52331ff391f6df 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,28 +1,19 @@ -import datetime -import hashlib -import tempfile -import chardet -import time -import uuid -from pathlib import Path - from cachetools import TTLCache from flask import request, current_app -from flask_login import current_user + +import services from core.login.login import login_required from flask_restful import Resource, marshal_with, fields -from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ UnsupportedFileTypeError + from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.data_loader.file_extractor import FileExtractor -from extensions.ext_storage import storage -from libs.helper import TimestampField -from extensions.ext_database import db -from models.model import UploadFile +from fields.file_fields import upload_config_fields, file_fields + +from services.file_service import FileService cache = TTLCache(maxsize=None, ttl=30) @@ -31,10 +22,6 @@ class FileApi(Resource): - upload_config_fields = { - 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer - } @setup_required @login_required @@ -48,16 +35,6 @@ def get(self): 'batch_count_limit': batch_count_limit }, 200 - file_fields = { - 'id': fields.String, - 'name': fields.String, - 'size': fields.Integer, - 'extension': fields.String, - 'mime_type': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - } - @setup_required @login_required @account_initialization_required @@ -73,45 +50,13 @@ def post(self): if len(request.files) > 1: raise TooManyFilesError() - - file_content = file.read() - file_size = len(file_content) - - file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 - if file_size > file_size_limit: - message = "({file_size} > {file_size_limit})" - raise FileTooLargeError(message) - - extension = file.filename.split('.')[-1] - if extension.lower() not in ALLOWED_EXTENSIONS: + try: + upload_file = FileService.upload_file(file) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - # user uuid as file name - file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension - - # save file to storage - storage.save(file_key, file_content) - - # save file to db - config = current_app.config - upload_file = UploadFile( - tenant_id=current_user.current_tenant_id, - storage_type=config['STORAGE_TYPE'], - key=file_key, - name=file.filename, - size=file_size, - extension=extension, - mime_type=file.mimetype, - created_by=current_user.id, - created_at=datetime.datetime.utcnow(), - used=False, - hash=hashlib.sha3_256(file_content).hexdigest() - ) - - db.session.add(upload_file) - db.session.commit() - return upload_file, 201 @@ -121,26 +66,7 @@ class FilePreviewApi(Resource): @account_initialization_required def get(self, file_id): file_id = str(file_id) - - key = file_id + request.path - cached_response = cache.get(key) - if cached_response and time.time() - cached_response['timestamp'] < cache.ttl: - return cached_response['response'] - - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() - - if not upload_file: - raise NotFound("File not found") - - # extract text from file - extension = upload_file.extension - if extension.lower() not in ALLOWED_EXTENSIONS: - raise UnsupportedFileTypeError() - - text = FileExtractor.load(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = FileService.get_file_preview(file_id) return {'content': text} diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index ac375207901c49..a2a703ca5f4b7f 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,7 +2,7 @@ from flask_login import current_user from core.login.login import login_required -from flask_restful import Resource, reqparse, marshal, fields +from flask_restful import Resource, reqparse, marshal from werkzeug.exceptions import InternalServerError, NotFound, Forbidden import services @@ -14,48 +14,10 @@ from controllers.console.wraps import account_initialization_required from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ LLMBadRequestError -from libs.helper import TimestampField +from fields.hit_testing_fields import hit_testing_record_fields from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService -document_fields = { - 'id': fields.String, - 'data_source_type': fields.String, - 'name': fields.String, - 'doc_type': fields.String, -} - -segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'document': fields.Nested(document_fields), -} - -hit_testing_record_fields = { - 'segment': fields.Nested(segment_fields), - 'score': fields.Float, - 'tsne_position': fields.Raw -} - class HitTestingApi(Resource): diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6381070288769b..aa56437f33166b 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -7,26 +7,12 @@ from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource +from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.web_conversation_service import WebConversationService -conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField -} - -conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_fields)) -} - class ConversationListApi(InstalledAppResource): @@ -76,7 +62,7 @@ def delete(self, installed_app, c_id): class ConversationRenameApi(InstalledAppResource): - @marshal_with(conversation_fields) + @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app if app_model.mode != 'chat': diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index b6592e94de8d11..0518de12ab17f2 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -11,32 +11,11 @@ from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from fields.installed_app_fields import installed_app_list_fields from libs.helper import TimestampField from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String -} - -installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': TimestampField, - 'editable': fields.Boolean, - 'uninstallable': fields.Boolean, -} - -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} - class InstalledAppsListApi(Resource): @login_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 0349c23ef3ae95..efc25aaa92ec51 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -17,6 +17,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError +from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value, TimestampField from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError @@ -26,45 +27,6 @@ class MessageListApi(InstalledAppResource): - feedback_fields = { - 'rating': fields.String - } - - retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField - } - - message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField - } - - message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) - } @marshal_with(message_infinite_scroll_pagination_fields) def get(self, installed_app): diff --git a/api/controllers/console/universal_chat/conversation.py b/api/controllers/console/universal_chat/conversation.py index 207d622ce970d2..c0782cb81a91ac 100644 --- a/api/controllers/console/universal_chat/conversation.py +++ b/api/controllers/console/universal_chat/conversation.py @@ -6,31 +6,17 @@ from controllers.console import api from controllers.console.universal_chat.wraps import UniversalChatResource +from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \ + conversation_with_model_config_fields from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.web_conversation_service import WebConversationService -conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Raw, -} - -conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_fields)) -} - class UniversalChatConversationListApi(UniversalChatResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) + @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields) def get(self, universal_app): app_model = universal_app @@ -73,7 +59,7 @@ def delete(self, universal_app, c_id): class UniversalChatConversationRenameApi(UniversalChatResource): - @marshal_with(conversation_fields) + @marshal_with(conversation_with_model_config_fields) def post(self, universal_app, c_id): app_model = universal_app conversation_id = str(c_id) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index ce5b44fc22bc44..0c3ec30072a72d 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -9,4 +9,4 @@ from .app import completion, app, conversation, message, audio -from .dataset import document +from .dataset import document, segment, dataset diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 26a95bd6be441d..2fabfdd1338561 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -8,25 +8,11 @@ from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource +from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import TimestampField, uuid_value import services from services.conversation_service import ConversationService -conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField -} - -conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_fields)) -} - class ConversationApi(AppApiResource): @@ -50,7 +36,7 @@ def get(self, app_model, end_user): raise NotFound("Last Conversation Not Exists.") class ConversationDetailApi(AppApiResource): - @marshal_with(conversation_fields) + @marshal_with(simple_conversation_fields) def delete(self, app_model, end_user, c_id): if app_model.mode != 'chat': raise NotChatAppError() @@ -70,7 +56,7 @@ def delete(self, app_model, end_user, c_id): class ConversationRenameApi(AppApiResource): - @marshal_with(conversation_fields) + @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): if app_model.mode != 'chat': raise NotChatAppError() diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py new file mode 100644 index 00000000000000..29f9a5613ce1cf --- /dev/null +++ b/api/controllers/service_api/dataset/dataset.py @@ -0,0 +1,84 @@ +from flask import request +from flask_restful import reqparse, marshal +import services.dataset_service +from controllers.service_api import api +from controllers.service_api.dataset.error import DatasetNameDuplicateError +from controllers.service_api.wraps import DatasetApiResource +from core.login.login import current_user +from core.model_providers.models.entity.model_params import ModelType +from extensions.ext_database import db +from fields.dataset_fields import dataset_detail_fields +from models.account import Account, TenantAccountJoin +from models.dataset import Dataset +from services.dataset_service import DatasetService +from services.provider_service import ProviderService + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError('Name must be between 1 to 40 characters.') + return name + + +class DatasetApi(DatasetApiResource): + """Resource for get datasets.""" + + def get(self, tenant_id): + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + provider = request.args.get('provider', default="vendor") + datasets, total = DatasetService.get_datasets(page, limit, provider, + tenant_id, current_user) + # check embedding setting + provider_service = ProviderService() + valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, + ModelType.EMBEDDINGS.value) + model_names = [] + for valid_model in valid_model_list: + model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") + data = marshal(datasets, dataset_detail_fields) + for item in data: + if item['indexing_technique'] == 'high_quality': + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if item_model in model_names: + item['embedding_available'] = True + else: + item['embedding_available'] = False + else: + item['embedding_available'] = True + response = { + 'data': data, + 'has_more': len(datasets) == limit, + 'limit': limit, + 'total': total, + 'page': page + } + return response, 200 + + """Resource for datasets.""" + + def post(self, tenant_id): + parser = reqparse.RequestParser() + parser.add_argument('name', nullable=False, required=True, + help='type is required. Name must be between 1 to 40 characters.', + type=_validate_name) + parser.add_argument('indexing_technique', type=str, location='json', + choices=('high_quality', 'economy'), + help='Invalid indexing technique.') + args = parser.parse_args() + + try: + dataset = DatasetService.create_empty_dataset( + tenant_id=tenant_id, + name=args['name'], + indexing_technique=args['indexing_technique'], + account=current_user + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return marshal(dataset, dataset_detail_fields), 200 + + +api.add_resource(DatasetApi, '/datasets') + diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 7cb4d49897a566..a8e1c4ab732457 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,114 +1,291 @@ import datetime +import json import uuid -from flask import current_app -from flask_restful import reqparse +from flask import current_app, request +from flask_restful import reqparse, marshal +from sqlalchemy import desc from werkzeug.exceptions import NotFound import services.dataset_service from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ - DatasetNotInitedError + NoFileUploadedError, TooManyFilesError from controllers.service_api.wraps import DatasetApiResource +from core.login.login import current_user from core.model_providers.error import ProviderTokenNotInitError from extensions.ext_database import db from extensions.ext_storage import storage +from fields.document_fields import document_fields, document_status_fields +from models.dataset import Dataset, Document, DocumentSegment from models.model import UploadFile from services.dataset_service import DocumentService +from services.file_service import FileService -class DocumentListApi(DatasetApiResource): +class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - def post(self, dataset): - """Create document.""" + def post(self, tenant_id, dataset_id): + """Create document by text.""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, nullable=False, location='json') parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('doc_type', type=str, location='json') - parser.add_argument('doc_metadata', type=dict, location='json') + parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') + parser.add_argument('original_document_id', type=str, required=False, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, + location='json') + parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, + location='json') + args = parser.parse_args() + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + raise ValueError('Dataset is not exist.') + + if not dataset.indexing_technique and not args['indexing_technique']: + raise ValueError('indexing_technique is required.') + + upload_file = FileService.upload_text(args.get('text'), args.get('name')) + data_source = { + 'type': 'upload_file', + 'info_list': { + 'data_source_type': 'upload_file', + 'file_info_list': { + 'file_ids': [upload_file.id] + } + } + } + args['data_source'] = data_source + # validate args + DocumentService.document_create_args_validate(args) + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + document_data=args, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, + created_from='api' + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = { + 'document': marshal(document, document_fields), + 'batch': batch + } + return documents_and_batch_fields, 200 + + +class DocumentUpdateByTextApi(DatasetApiResource): + """Resource for update documents.""" + + def post(self, tenant_id, dataset_id, document_id): + """Update document by text.""" + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=False, nullable=True, location='json') + parser.add_argument('text', type=str, required=False, nullable=True, location='json') + parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, + location='json') args = parser.parse_args() + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() - if not dataset.indexing_technique: - raise DatasetNotInitedError("Dataset indexing technique must be set.") - - doc_type = args.get('doc_type') - doc_metadata = args.get('doc_metadata') - - if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') - - # user uuid as file name - file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt' - - # save file to storage - storage.save(file_key, args.get('text')) - - # save file to db - config = current_app.config - upload_file = UploadFile( - tenant_id=dataset.tenant_id, - storage_type=config['STORAGE_TYPE'], - key=file_key, - name=args.get('name') + '.txt', - size=len(args.get('text')), - extension='txt', - mime_type='text/plain', - created_by=dataset.created_by, - created_at=datetime.datetime.utcnow(), - used=True, - used_by=dataset.created_by, - used_at=datetime.datetime.utcnow() - ) - - db.session.add(upload_file) - db.session.commit() - - document_data = { - 'data_source': { + if not dataset: + raise ValueError('Dataset is not exist.') + + if args['text']: + upload_file = FileService.upload_text(args.get('text'), args.get('name')) + data_source = { 'type': 'upload_file', - 'info': [ - { - 'upload_file_id': upload_file.id + 'info_list': { + 'data_source_type': 'upload_file', + 'file_info_list': { + 'file_ids': [upload_file.id] } - ] + } + } + args['data_source'] = data_source + # validate args + args['original_document_id'] = str(document_id) + DocumentService.document_create_args_validate(args) + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + document_data=args, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, + created_from='api' + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = { + 'document': marshal(document, document_fields), + 'batch': batch + } + return documents_and_batch_fields, 200 + + +class DocumentAddByFileApi(DatasetApiResource): + """Resource for documents.""" + def post(self, tenant_id, dataset_id): + """Create document by upload file.""" + args = {} + if 'data' in request.form: + args = json.loads(request.form['data']) + if 'doc_form' not in args: + args['doc_form'] = 'text_model' + if 'doc_language' not in args: + args['doc_language'] = 'English' + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + raise ValueError('Dataset is not exist.') + if not dataset.indexing_technique and not args['indexing_technique']: + raise ValueError('indexing_technique is required.') + + # save file info + file = request.files['file'] + # check file + if 'file' not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + upload_file = FileService.upload_file(file) + data_source = { + 'type': 'upload_file', + 'info_list': { + 'file_info_list': { + 'file_ids': [upload_file.id] + } } } + args['data_source'] = data_source + # validate args + DocumentService.document_create_args_validate(args) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=document_data, + document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule, + dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, created_from='api' ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - if doc_type and doc_metadata: - metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + documents_and_batch_fields = { + 'document': marshal(document, document_fields), + 'batch': batch + } + return documents_and_batch_fields, 200 - document.doc_metadata = {} - for key, value_type in metadata_schema.items(): - value = doc_metadata.get(key) - if value is not None and isinstance(value, value_type): - document.doc_metadata[key] = value +class DocumentUpdateByFileApi(DatasetApiResource): + """Resource for update documents.""" - document.doc_type = doc_type - document.updated_at = datetime.datetime.utcnow() - db.session.commit() + def post(self, tenant_id, dataset_id, document_id): + """Update document by upload file.""" + args = {} + if 'data' in request.form: + args = json.loads(request.form['data']) + if 'doc_form' not in args: + args['doc_form'] = 'text_model' + if 'doc_language' not in args: + args['doc_language'] = 'English' - return {'id': document.id} + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + raise ValueError('Dataset is not exist.') + if 'file' in request.files: + # save file info + file = request.files['file'] + + + if len(request.files) > 1: + raise TooManyFilesError() + + upload_file = FileService.upload_file(file) + data_source = { + 'type': 'upload_file', + 'info_list': { + 'file_info_list': { + 'file_ids': [upload_file.id] + } + } + } + args['data_source'] = data_source + # validate args + args['original_document_id'] = str(document_id) + DocumentService.document_create_args_validate(args) + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + document_data=args, + account=dataset.created_by_account, + dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, + created_from='api' + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + documents_and_batch_fields = { + 'document': marshal(document, document_fields), + 'batch': batch + } + return documents_and_batch_fields, 200 -class DocumentApi(DatasetApiResource): - def delete(self, dataset, document_id): +class DocumentDeleteApi(DatasetApiResource): + def delete(self, tenant_id, dataset_id, document_id): """Delete document.""" document_id = str(document_id) + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + + # get dataset info + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + raise ValueError('Dataset is not exist.') document = DocumentService.get_document(dataset.id, document_id) @@ -126,8 +303,85 @@ def delete(self, dataset, document_id): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError('Cannot delete document during indexing.') - return {'result': 'success'}, 204 + return {'result': 'success'}, 200 + + +class DocumentListApi(DatasetApiResource): + def get(self, tenant_id, dataset_id): + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + search = request.args.get('keyword', default=None, type=str) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + if not dataset: + raise NotFound('Dataset not found.') + + query = Document.query.filter_by( + dataset_id=str(dataset_id), tenant_id=tenant_id) + + if search: + search = f'%{search}%' + query = query.filter(Document.name.like(search)) + + query = query.order_by(desc(Document.created_at)) + + paginated_documents = query.paginate( + page=page, per_page=limit, max_per_page=100, error_out=False) + documents = paginated_documents.items + + response = { + 'data': marshal(documents, document_fields), + 'has_more': len(documents) == limit, + 'limit': limit, + 'total': paginated_documents.total, + 'page': page + } + + return response + + +class DocumentIndexingStatusApi(DatasetApiResource): + def get(self, tenant_id, dataset_id, batch): + dataset_id = str(dataset_id) + batch = str(batch) + tenant_id = str(tenant_id) + # get dataset + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + if not dataset: + raise NotFound('Dataset not found.') + # get documents + documents = DocumentService.get_batch_documents(dataset_id, batch) + if not documents: + raise NotFound('Documents not found.') + documents_status = [] + for document in documents: + completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != 're_segment').count() + total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), + DocumentSegment.status != 're_segment').count() + document.completed_segments = completed_segments + document.total_segments = total_segments + if document.is_paused: + document.indexing_status = 'paused' + documents_status.append(marshal(document, document_status_fields)) + data = { + 'data': documents_status + } + return data -api.add_resource(DocumentListApi, '/documents') -api.add_resource(DocumentApi, '/documents/') +api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') +api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') +api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') +api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') +api.add_resource(DocumentDeleteApi, '/datasets//documents/') +api.add_resource(DocumentListApi, '/datasets//documents') +api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index 2131fe0bacdd52..29142b80e627f3 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -1,20 +1,73 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException +class NoFileUploadedError(BaseHTTPException): + error_code = 'no_file_uploaded' + description = "Please upload your file." + code = 400 + + +class TooManyFilesError(BaseHTTPException): + error_code = 'too_many_files' + description = "Only one file is allowed." + code = 400 + + +class FileTooLargeError(BaseHTTPException): + error_code = 'file_too_large' + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed." + code = 415 + + +class HighQualityDatasetOnlyError(BaseHTTPException): + error_code = 'high_quality_dataset_only' + description = "Current operation only supports 'high-quality' datasets." + code = 400 + + +class DatasetNotInitializedError(BaseHTTPException): + error_code = 'dataset_not_initialized' + description = "The dataset is still being initialized or indexing. Please wait a moment." + code = 400 + + class ArchivedDocumentImmutableError(BaseHTTPException): error_code = 'archived_document_immutable' - description = "Cannot operate when document was archived." + description = "The archived document is not editable." code = 403 +class DatasetNameDuplicateError(BaseHTTPException): + error_code = 'dataset_name_duplicate' + description = "The dataset name already exists. Please modify your dataset name." + code = 409 + + +class InvalidActionError(BaseHTTPException): + error_code = 'invalid_action' + description = "Invalid action." + code = 400 + + +class DocumentAlreadyFinishedError(BaseHTTPException): + error_code = 'document_already_finished' + description = "The document has been processed. Please refresh the page or go to the document details." + code = 400 + + class DocumentIndexingError(BaseHTTPException): error_code = 'document_indexing' - description = "Cannot operate document during indexing." - code = 403 + description = "The document is being processed and cannot be edited." + code = 400 -class DatasetNotInitedError(BaseHTTPException): - error_code = 'dataset_not_inited' - description = "The dataset is still being initialized or indexing. Please wait a moment." - code = 403 +class InvalidMetadataError(BaseHTTPException): + error_code = 'invalid_metadata' + description = "The metadata content is incorrect. Please check and verify." + code = 400 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py new file mode 100644 index 00000000000000..70c3e73d917011 --- /dev/null +++ b/api/controllers/service_api/dataset/segment.py @@ -0,0 +1,59 @@ +from flask_login import current_user +from flask_restful import reqparse, marshal +from werkzeug.exceptions import NotFound + +from controllers.service_api import api +from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.wraps import DatasetApiResource +from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError +from core.model_providers.model_factory import ModelFactory +from extensions.ext_database import db +from fields.segment_fields import segment_fields +from models.dataset import Dataset +from services.dataset_service import DocumentService, SegmentService + + +class SegmentApi(DatasetApiResource): + """Resource for segments.""" + def post(self, tenant_id, dataset_id, document_id): + """Create single segment.""" + # check dataset + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound('Document not found.') + # check embedding model setting + if dataset.indexing_technique == 'high_quality': + try: + ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. Please configure a valid provider " + f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + # validate args + parser = reqparse.RequestParser() + parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + args = parser.parse_args() + for args_item in args['segments']: + SegmentService.segment_create_args_validate(args_item, document) + segments = SegmentService.multi_create_segment(args['segments'], document, dataset) + return { + 'data': marshal(segments, segment_fields), + 'doc_form': document.doc_form + }, 200 + + +api.add_resource(SegmentApi, '/datasets//documents//segments') diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index b76928daa6b749..44d051707b946e 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -2,11 +2,14 @@ from datetime import datetime from functools import wraps -from flask import request +from flask import request, current_app +from flask_login import user_logged_in from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized +from core.login.login import _get_user from extensions.ext_database import db +from models.account import Tenant, TenantAccountJoin, Account from models.dataset import Dataset from models.model import ApiToken, App @@ -43,12 +46,24 @@ def decorator(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token('dataset') - - dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first() - if not dataset: - raise NotFound() - - return view(dataset, *args, **kwargs) + tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ + .filter(Tenant.id == api_token.tenant_id) \ + .filter(TenantAccountJoin.tenant_id == Tenant.id) \ + .filter(TenantAccountJoin.role == 'owner') \ + .one_or_none() + if tenant_account_join: + tenant, ta = tenant_account_join + account = Account.query.filter_by(id=ta.account_id).first() + # Login admin + if account: + account.current_tenant = tenant + current_app.login_manager._update_request_context_with_user(account) + user_logged_in.send(current_app._get_current_object(), user=_get_user()) + else: + raise Unauthorized("Tenant owner account is not exist.") + else: + raise Unauthorized("Tenant is not exist.") + return view(api_token.tenant_id, *args, **kwargs) return decorated if view: diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index d08e22c16b08ef..ce089ca39572f5 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -6,26 +6,12 @@ from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource +from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import TimestampField, uuid_value from services.conversation_service import ConversationService from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.web_conversation_service import WebConversationService -conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField -} - -conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_fields)) -} - class ConversationListApi(WebApiResource): @@ -73,7 +59,7 @@ def delete(self, app_model, end_user, c_id): class ConversationRenameApi(WebApiResource): - @marshal_with(conversation_fields) + @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): if app_model.mode != 'chat': raise NotChatAppError() diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index e54266f4ce292e..2162df83e15b78 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -16,6 +16,7 @@ BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" SEARCH_URL = "https://api.notion.com/v1/search" + RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 38f68c7a5057bc..b315de6191ec58 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -246,11 +246,28 @@ def create_segment_keywords(self, node_id: str, keywords: List[str]): keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) self._save_dataset_keyword_table(keyword_table) + def multi_create_segment_keywords(self, pre_segment_data_list: list): + keyword_table_handler = JiebaKeywordTableHandler() + keyword_table = self._get_dataset_keyword_table() + for pre_segment_data in pre_segment_data_list: + segment = pre_segment_data['segment'] + if pre_segment_data['keywords']: + segment.keywords = pre_segment_data['keywords'] + keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, + pre_segment_data['keywords']) + else: + keywords = keyword_table_handler.extract_keywords(segment.content, + self._config.max_keywords_per_chunk) + segment.keywords = list(keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) + self._save_dataset_keyword_table(keyword_table) + def update_segment_keywords_index(self, node_id: str, keywords: List[str]): keyword_table = self._get_dataset_keyword_table() keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) self._save_dataset_keyword_table(keyword_table) + class KeywordTableRetriever(BaseRetriever, BaseModel): index: KeywordTableIndex search_kwargs: dict = Field(default_factory=dict) diff --git a/api/fields/__init__.py b/api/fields/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py new file mode 100644 index 00000000000000..b370ac41e058e0 --- /dev/null +++ b/api/fields/app_fields.py @@ -0,0 +1,138 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +app_detail_kernel_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, +} + +related_app_list = { + 'data': fields.List(fields.Nested(app_detail_kernel_fields)), + 'total': fields.Integer, +} + +model_config_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), + 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), + 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), + 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), + 'more_like_this': fields.Raw(attribute='more_like_this_dict'), + 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), + 'model': fields.Raw(attribute='model_dict'), + 'user_input_form': fields.Raw(attribute='user_input_form_list'), + 'dataset_query_variable': fields.String, + 'pre_prompt': fields.String, + 'agent_mode': fields.Raw(attribute='agent_mode_dict'), +} + +app_detail_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'api_rpm': fields.Integer, + 'api_rph': fields.Integer, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'created_at': TimestampField +} + +prompt_config_fields = { + 'prompt_template': fields.String, +} + +model_config_partial_fields = { + 'model': fields.Raw(attribute='model_dict'), + 'pre_prompt': fields.String, +} + +app_partial_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), + 'created_at': TimestampField +} + +app_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') +} + +template_fields = { + 'name': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'mode': fields.String, + 'model_config': fields.Nested(model_config_fields), +} + +template_list_fields = { + 'data': fields.List(fields.Nested(template_fields)), +} + +site_fields = { + 'access_token': fields.String(attribute='code'), + 'code': fields.String, + 'title': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'default_language': fields.String, + 'customize_domain': fields.String, + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'customize_token_strategy': fields.String, + 'prompt_public': fields.Boolean, + 'app_base_url': fields.String, +} + +app_detail_fields_with_site = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'api_rpm': fields.Integer, + 'api_rph': fields.Integer, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'site': fields.Nested(site_fields), + 'api_base_url': fields.String, + 'created_at': TimestampField +} + +app_site_fields = { + 'app_id': fields.String, + 'access_token': fields.String(attribute='code'), + 'code': fields.String, + 'title': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'default_language': fields.String, + 'customize_domain': fields.String, + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'customize_token_strategy': fields.String, + 'prompt_public': fields.Boolean +} \ No newline at end of file diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py new file mode 100644 index 00000000000000..dcfbe8a0694a98 --- /dev/null +++ b/api/fields/conversation_fields.py @@ -0,0 +1,182 @@ +from flask_restful import fields + +from libs.helper import TimestampField + + +class MessageTextField(fields.Raw): + def format(self, value): + return value[0]['text'] if value else '' + + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + +feedback_fields = { + 'rating': fields.String, + 'content': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account': fields.Nested(account_fields, allow_null=True), +} + +annotation_fields = { + 'content': fields.String, + 'account': fields.Nested(account_fields, allow_null=True), + 'created_at': TimestampField +} + +message_detail_fields = { + 'id': fields.String, + 'conversation_id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'message': fields.Raw, + 'message_tokens': fields.Integer, + 'answer': fields.String, + 'answer_tokens': fields.Integer, + 'provider_response_latency': fields.Float, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'feedbacks': fields.List(fields.Nested(feedback_fields)), + 'annotation': fields.Nested(annotation_fields, allow_null=True), + 'created_at': TimestampField +} + +feedback_stat_fields = { + 'like': fields.Integer, + 'dislike': fields.Integer +} + +model_config_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw, + 'model': fields.Raw, + 'user_input_form': fields.Raw, + 'pre_prompt': fields.String, + 'agent_mode': fields.Raw, +} + +simple_configs_fields = { + 'prompt_template': fields.String, +} + +simple_model_config_fields = { + 'model': fields.Raw(attribute='model_dict'), + 'pre_prompt': fields.String, +} + +simple_message_detail_fields = { + 'inputs': fields.Raw, + 'query': fields.String, + 'message': MessageTextField, + 'answer': fields.String, +} + +conversation_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_end_user_session_id': fields.String(), + 'from_account_id': fields.String, + 'read_at': TimestampField, + 'created_at': TimestampField, + 'annotation': fields.Nested(annotation_fields, allow_null=True), + 'model_config': fields.Nested(simple_model_config_fields), + 'user_feedback_stats': fields.Nested(feedback_stat_fields), + 'admin_feedback_stats': fields.Nested(feedback_stat_fields), + 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') +} + +conversation_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(conversation_fields), attribute='items') +} + +conversation_message_detail_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'created_at': TimestampField, + 'model_config': fields.Nested(model_config_fields), + 'message': fields.Nested(message_detail_fields, attribute='first_message'), +} + +simple_model_config_fields = { + 'model': fields.Raw(attribute='model_dict'), + 'pre_prompt': fields.String, +} + +conversation_with_summary_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_end_user_session_id': fields.String, + 'from_account_id': fields.String, + 'summary': fields.String(attribute='summary_or_query'), + 'read_at': TimestampField, + 'created_at': TimestampField, + 'annotated': fields.Boolean, + 'model_config': fields.Nested(simple_model_config_fields), + 'message_count': fields.Integer, + 'user_feedback_stats': fields.Nested(feedback_stat_fields), + 'admin_feedback_stats': fields.Nested(feedback_stat_fields) +} + +conversation_with_summary_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') +} + +conversation_detail_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'created_at': TimestampField, + 'annotated': fields.Boolean, + 'model_config': fields.Nested(model_config_fields), + 'message_count': fields.Integer, + 'user_feedback_stats': fields.Nested(feedback_stat_fields), + 'admin_feedback_stats': fields.Nested(feedback_stat_fields) +} + +simple_conversation_fields = { + 'id': fields.String, + 'name': fields.String, + 'inputs': fields.Raw, + 'status': fields.String, + 'introduction': fields.String, + 'created_at': TimestampField +} + +conversation_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(simple_conversation_fields)) +} + +conversation_with_model_config_fields = { + **simple_conversation_fields, + 'model_config': fields.Raw, +} + +conversation_with_model_config_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) +} \ No newline at end of file diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py new file mode 100644 index 00000000000000..6f3c920c85b60f --- /dev/null +++ b/api/fields/data_source_fields.py @@ -0,0 +1,65 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +integrate_icon_fields = { + 'type': fields.String, + 'url': fields.String, + 'emoji': fields.String +} + +integrate_page_fields = { + 'page_name': fields.String, + 'page_id': fields.String, + 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), + 'is_bound': fields.Boolean, + 'parent_id': fields.String, + 'type': fields.String +} + +integrate_workspace_fields = { + 'workspace_name': fields.String, + 'workspace_id': fields.String, + 'workspace_icon': fields.String, + 'pages': fields.List(fields.Nested(integrate_page_fields)) +} + +integrate_notion_info_list_fields = { + 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), +} + +integrate_icon_fields = { + 'type': fields.String, + 'url': fields.String, + 'emoji': fields.String +} + +integrate_page_fields = { + 'page_name': fields.String, + 'page_id': fields.String, + 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), + 'parent_id': fields.String, + 'type': fields.String +} + +integrate_workspace_fields = { + 'workspace_name': fields.String, + 'workspace_id': fields.String, + 'workspace_icon': fields.String, + 'pages': fields.List(fields.Nested(integrate_page_fields)), + 'total': fields.Integer +} + +integrate_fields = { + 'id': fields.String, + 'provider': fields.String, + 'created_at': TimestampField, + 'is_bound': fields.Boolean, + 'disabled': fields.Boolean, + 'link': fields.String, + 'source_info': fields.Nested(integrate_workspace_fields) +} + +integrate_list_fields = { + 'data': fields.List(fields.Nested(integrate_fields)), +} \ No newline at end of file diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py new file mode 100644 index 00000000000000..90af9e1fdd9b17 --- /dev/null +++ b/api/fields/dataset_fields.py @@ -0,0 +1,43 @@ +from flask_restful import fields +from libs.helper import TimestampField + +dataset_fields = { + 'id': fields.String, + 'name': fields.String, + 'description': fields.String, + 'permission': fields.String, + 'data_source_type': fields.String, + 'indexing_technique': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, +} + +dataset_detail_fields = { + 'id': fields.String, + 'name': fields.String, + 'description': fields.String, + 'provider': fields.String, + 'permission': fields.String, + 'data_source_type': fields.String, + 'indexing_technique': fields.String, + 'app_count': fields.Integer, + 'document_count': fields.Integer, + 'word_count': fields.Integer, + 'created_by': fields.String, + 'created_at': TimestampField, + 'updated_by': fields.String, + 'updated_at': TimestampField, + 'embedding_model': fields.String, + 'embedding_model_provider': fields.String, + 'embedding_available': fields.Boolean +} + +dataset_query_detail_fields = { + "id": fields.String, + "content": fields.String, + "source": fields.String, + "source_app_id": fields.String, + "created_by_role": fields.String, + "created_by": fields.String, + "created_at": TimestampField +} diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py new file mode 100644 index 00000000000000..94d905eafe00d3 --- /dev/null +++ b/api/fields/document_fields.py @@ -0,0 +1,76 @@ +from flask_restful import fields + +from fields.dataset_fields import dataset_fields +from libs.helper import TimestampField + +document_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'data_source_type': fields.String, + 'data_source_info': fields.Raw(attribute='data_source_info_dict'), + 'dataset_process_rule_id': fields.String, + 'name': fields.String, + 'created_from': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'tokens': fields.Integer, + 'indexing_status': fields.String, + 'error': fields.String, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'archived': fields.Boolean, + 'display_status': fields.String, + 'word_count': fields.Integer, + 'hit_count': fields.Integer, + 'doc_form': fields.String, +} + +document_with_segments_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'data_source_type': fields.String, + 'data_source_info': fields.Raw(attribute='data_source_info_dict'), + 'dataset_process_rule_id': fields.String, + 'name': fields.String, + 'created_from': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'tokens': fields.Integer, + 'indexing_status': fields.String, + 'error': fields.String, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'archived': fields.Boolean, + 'display_status': fields.String, + 'word_count': fields.Integer, + 'hit_count': fields.Integer, + 'completed_segments': fields.Integer, + 'total_segments': fields.Integer +} + +dataset_and_document_fields = { + 'dataset': fields.Nested(dataset_fields), + 'documents': fields.List(fields.Nested(document_fields)), + 'batch': fields.String +} + +document_status_fields = { + 'id': fields.String, + 'indexing_status': fields.String, + 'processing_started_at': TimestampField, + 'parsing_completed_at': TimestampField, + 'cleaning_completed_at': TimestampField, + 'splitting_completed_at': TimestampField, + 'completed_at': TimestampField, + 'paused_at': TimestampField, + 'error': fields.String, + 'stopped_at': TimestampField, + 'completed_segments': fields.Integer, + 'total_segments': fields.Integer, +} + +document_status_fields_list = { + 'data': fields.List(fields.Nested(document_status_fields)) +} \ No newline at end of file diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py new file mode 100644 index 00000000000000..dcab11a7ad8ee1 --- /dev/null +++ b/api/fields/file_fields.py @@ -0,0 +1,18 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +upload_config_fields = { + 'file_size_limit': fields.Integer, + 'batch_count_limit': fields.Integer +} + +file_fields = { + 'id': fields.String, + 'name': fields.String, + 'size': fields.Integer, + 'extension': fields.String, + 'mime_type': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, +} \ No newline at end of file diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py new file mode 100644 index 00000000000000..541e56a378dae4 --- /dev/null +++ b/api/fields/hit_testing_fields.py @@ -0,0 +1,41 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +document_fields = { + 'id': fields.String, + 'data_source_type': fields.String, + 'name': fields.String, + 'doc_type': fields.String, +} + +segment_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'document_id': fields.String, + 'content': fields.String, + 'answer': fields.String, + 'word_count': fields.Integer, + 'tokens': fields.Integer, + 'keywords': fields.List(fields.String), + 'index_node_id': fields.String, + 'index_node_hash': fields.String, + 'hit_count': fields.Integer, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'status': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'indexing_at': TimestampField, + 'completed_at': TimestampField, + 'error': fields.String, + 'stopped_at': TimestampField, + 'document': fields.Nested(document_fields), +} + +hit_testing_record_fields = { + 'segment': fields.Nested(segment_fields), + 'score': fields.Float, + 'tsne_position': fields.Raw +} \ No newline at end of file diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py new file mode 100644 index 00000000000000..79af54bdd6a794 --- /dev/null +++ b/api/fields/installed_app_fields.py @@ -0,0 +1,25 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +app_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String +} + +installed_app_fields = { + 'id': fields.String, + 'app': fields.Nested(app_fields), + 'app_owner_tenant_id': fields.String, + 'is_pinned': fields.Boolean, + 'last_used_at': TimestampField, + 'editable': fields.Boolean, + 'uninstallable': fields.Boolean, +} + +installed_app_list_fields = { + 'installed_apps': fields.List(fields.Nested(installed_app_fields)) +} \ No newline at end of file diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py new file mode 100644 index 00000000000000..c2f7193f65fb53 --- /dev/null +++ b/api/fields/message_fields.py @@ -0,0 +1,43 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +feedback_fields = { + 'rating': fields.String +} + +retriever_resource_fields = { + 'id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': fields.String, + 'content': fields.String, + 'created_at': TimestampField +} + +message_fields = { + 'id': fields.String, + 'conversation_id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'answer': fields.String, + 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), + 'created_at': TimestampField +} + +message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_fields)) +} \ No newline at end of file diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py new file mode 100644 index 00000000000000..8c07c093213e93 --- /dev/null +++ b/api/fields/segment_fields.py @@ -0,0 +1,32 @@ +from flask_restful import fields +from libs.helper import TimestampField + +segment_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'document_id': fields.String, + 'content': fields.String, + 'answer': fields.String, + 'word_count': fields.Integer, + 'tokens': fields.Integer, + 'keywords': fields.List(fields.String), + 'index_node_id': fields.String, + 'index_node_hash': fields.String, + 'hit_count': fields.Integer, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'status': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'indexing_at': TimestampField, + 'completed_at': TimestampField, + 'error': fields.String, + 'stopped_at': TimestampField +} + +segment_list_response = { + 'data': fields.List(fields.Nested(segment_fields)), + 'has_more': fields.Boolean, + 'limit': fields.Integer +} diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py new file mode 100644 index 00000000000000..23ecabe9d22c7f --- /dev/null +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -0,0 +1,36 @@ +"""add_tenant_id_in_api_token + +Revision ID: 2e9819ca5b28 +Revises: 6e2cfb077b04 +Create Date: 2023-09-22 15:41:01.243183 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '2e9819ca5b28' +down_revision = 'ab23c11305d4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) + batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) + batch_op.drop_column('dataset_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) + batch_op.drop_index('api_token_tenant_idx') + batch_op.drop_column('tenant_id') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index dd7f09bf78a1db..9a0b8c6554cd10 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -629,12 +629,13 @@ class ApiToken(db.Model): __table_args__ = ( db.PrimaryKeyConstraint('id', name='api_token_pkey'), db.Index('api_token_app_id_type_idx', 'app_id', 'type'), - db.Index('api_token_token_idx', 'token', 'type') + db.Index('api_token_token_idx', 'token', 'type'), + db.Index('api_token_tenant_idx', 'tenant_id', 'type') ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=True) - dataset_id = db.Column(UUID, nullable=True) + tenant_id = db.Column(UUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 78686112a58137..71581d47bf52ac 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -96,7 +96,7 @@ def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional embedding_model = None if indexing_technique == 'high_quality': embedding_model = ModelFactory.get_embedding_model( - tenant_id=current_user.current_tenant_id + tenant_id=tenant_id ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) @@ -477,6 +477,7 @@ def save_document_with_dataset_id(dataset: Dataset, document_data: dict, ) dataset.collection_binding_id = dataset_collection_binding.id + documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) if 'original_document_id' in document_data and document_data["original_document_id"]: @@ -626,6 +627,9 @@ def update_document_with_dataset_id(dataset: Dataset, document_data: dict, document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) if document.display_status != 'available': raise ValueError("Document is not available") + # update document name + if 'name' in document_data and document_data['name']: + document.name = document_data['name'] # save process rule if 'process_rule' in document_data and document_data['process_rule']: process_rule = document_data["process_rule"] @@ -767,7 +771,7 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): + def document_create_args_validate(cls, args: dict): if 'original_document_id' not in args or not args['original_document_id']: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) @@ -1014,6 +1018,66 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() return segment + @classmethod + def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): + embedding_model = None + if dataset.indexing_technique == 'high_quality': + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == document.id + ).scalar() + pre_segment_data_list = [] + segment_data_list = [] + for segment_item in segments: + content = segment_item['content'] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == 'high_quality' and embedding_model: + # calc embedding use tokens + tokens = embedding_model.get_num_tokens(content) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + status='completed', + indexing_at=datetime.datetime.utcnow(), + completed_at=datetime.datetime.utcnow(), + created_by=current_user.id + ) + if document.doc_form == 'qa_model': + segment_document.answer = segment_item['answer'] + db.session.add(segment_document) + segment_data_list.append(segment_document) + pre_segment_data = { + 'segment': segment_document, + 'keywords': segment_item['keywords'] + } + pre_segment_data_list.append(pre_segment_data) + + try: + # save vector index + VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) + except Exception as e: + logging.exception("create segment index failed") + for segment_document in segment_data_list: + segment_document.enabled = False + segment_document.disabled_at = datetime.datetime.utcnow() + segment_document.status = 'error' + segment_document.error = str(e) + db.session.commit() + return segment_data_list + @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = 'segment_{}_indexing'.format(segment.id) diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 1f01e033caa30e..5804f599fe63bf 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'app', 'completion', 'audio' + 'app', 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/file.py b/api/services/errors/file.py index 3674eca3e7ab39..29f3f44eece89d 100644 --- a/api/services/errors/file.py +++ b/api/services/errors/file.py @@ -3,3 +3,11 @@ class FileNotExistsError(BaseServiceError): pass + + +class FileTooLargeError(BaseServiceError): + description = "{message}" + + +class UnsupportedFileTypeError(BaseServiceError): + pass diff --git a/api/services/file_service.py b/api/services/file_service.py new file mode 100644 index 00000000000000..79e53738e04f3c --- /dev/null +++ b/api/services/file_service.py @@ -0,0 +1,123 @@ +import datetime +import hashlib +import time +import uuid + +from cachetools import TTLCache +from flask import request, current_app +from flask_login import current_user +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + +from core.data_loader.file_extractor import FileExtractor +from extensions.ext_storage import storage +from extensions.ext_database import db +from models.model import UploadFile +from services.errors.file import FileTooLargeError, UnsupportedFileTypeError + +ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] +PREVIEW_WORDS_LIMIT = 3000 +cache = TTLCache(maxsize=None, ttl=30) + + +class FileService: + + @staticmethod + def upload_file(file: FileStorage) -> UploadFile: + # read file content + file_content = file.read() + # get file size + file_size = len(file_content) + + file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 + if file_size > file_size_limit: + message = f'File size exceeded. {file_size} > {file_size_limit}' + raise FileTooLargeError(message) + + extension = file.filename.split('.')[-1] + if extension.lower() not in ALLOWED_EXTENSIONS: + raise UnsupportedFileTypeError() + + # user uuid as file name + file_uuid = str(uuid.uuid4()) + file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension + + # save file to storage + storage.save(file_key, file_content) + + # save file to db + config = current_app.config + upload_file = UploadFile( + tenant_id=current_user.current_tenant_id, + storage_type=config['STORAGE_TYPE'], + key=file_key, + name=file.filename, + size=file_size, + extension=extension, + mime_type=file.mimetype, + created_by=current_user.id, + created_at=datetime.datetime.utcnow(), + used=False, + hash=hashlib.sha3_256(file_content).hexdigest() + ) + + db.session.add(upload_file) + db.session.commit() + + return upload_file + + @staticmethod + def upload_text(text: str, text_name: str) -> UploadFile: + # user uuid as file name + file_uuid = str(uuid.uuid4()) + file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + + # save file to storage + storage.save(file_key, text.encode('utf-8')) + + # save file to db + config = current_app.config + upload_file = UploadFile( + tenant_id=current_user.current_tenant_id, + storage_type=config['STORAGE_TYPE'], + key=file_key, + name=text_name + '.txt', + size=len(text), + extension='txt', + mime_type='text/plain', + created_by=current_user.id, + created_at=datetime.datetime.utcnow(), + used=True, + used_by=current_user.id, + used_at=datetime.datetime.utcnow() + ) + + db.session.add(upload_file) + db.session.commit() + + return upload_file + + @staticmethod + def get_file_preview(file_id: str) -> str: + # get file storage key + key = file_id + request.path + cached_response = cache.get(key) + if cached_response and time.time() - cached_response['timestamp'] < cache.ttl: + return cached_response['response'] + + upload_file = db.session.query(UploadFile) \ + .filter(UploadFile.id == file_id) \ + .first() + + if not upload_file: + raise NotFound("File not found") + + # extract text from file + extension = upload_file.extension + if extension.lower() not in ALLOWED_EXTENSIONS: + raise UnsupportedFileTypeError() + + text = FileExtractor.load(upload_file, return_text=True) + text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + + return text diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3cb8eb0099c95b..45bf611fd4b092 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -35,6 +35,32 @@ def create_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentS else: index.add_texts([document]) + @classmethod + def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): + documents = [] + for pre_segment_data in pre_segment_data_list: + segment = pre_segment_data['segment'] + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + documents.append(document) + + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts(documents, duplicate_check=True) + + # save keyword index + keyword_index = IndexBuilder.get_index(dataset, 'economy') + if keyword_index: + keyword_index.multi_create_segment_keywords(pre_segment_data_list) + @classmethod def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): # update segment index task From 9dbb8acd4b8f4696b4ac3c56e9cf8b3e8fd6bbbf Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 27 Sep 2023 16:06:49 +0800 Subject: [PATCH 3/3] Feat/dataset support api service (#1240) Co-authored-by: Joel Co-authored-by: crazywoola <427733928@qq.com> --- web/app/(commonLayout)/datasets/ApiServer.tsx | 41 + web/app/(commonLayout)/datasets/Container.tsx | 60 ++ web/app/(commonLayout)/datasets/Datasets.tsx | 19 +- web/app/(commonLayout)/datasets/Doc.tsx | 28 + web/app/(commonLayout)/datasets/page.tsx | 8 +- .../datasets/template/template.en.mdx | 791 +++++++++++++++++ .../datasets/template/template.zh.mdx | 792 ++++++++++++++++++ web/app/components/base/tab-slider/index.tsx | 55 ++ .../develop/secret-key/secret-key-button.tsx | 2 +- .../develop/secret-key/secret-key-modal.tsx | 30 +- web/i18n/lang/app-api.zh.ts | 2 +- web/i18n/lang/dataset.en.ts | 2 + web/i18n/lang/dataset.zh.ts | 2 + web/service/datasets.ts | 20 + 14 files changed, 1832 insertions(+), 20 deletions(-) create mode 100644 web/app/(commonLayout)/datasets/ApiServer.tsx create mode 100644 web/app/(commonLayout)/datasets/Container.tsx create mode 100644 web/app/(commonLayout)/datasets/Doc.tsx create mode 100644 web/app/(commonLayout)/datasets/template/template.en.mdx create mode 100644 web/app/(commonLayout)/datasets/template/template.zh.mdx create mode 100644 web/app/components/base/tab-slider/index.tsx diff --git a/web/app/(commonLayout)/datasets/ApiServer.tsx b/web/app/(commonLayout)/datasets/ApiServer.tsx new file mode 100644 index 00000000000000..675dda354e5012 --- /dev/null +++ b/web/app/(commonLayout)/datasets/ApiServer.tsx @@ -0,0 +1,41 @@ +'use client' + +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import CopyFeedback from '@/app/components/base/copy-feedback' +import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' +import { randomString } from '@/utils' + +type ApiServerProps = { + apiBaseUrl: string +} +const ApiServer: FC = ({ + apiBaseUrl, +}) => { + const { t } = useTranslation() + + return ( +
+
+
{t('appApi.apiServer')}
+
{apiBaseUrl}
+
+ +
+
+ {t('appApi.ok')} +
+ +
+ ) +} + +export default ApiServer diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx new file mode 100644 index 00000000000000..871ef04951be1a --- /dev/null +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -0,0 +1,60 @@ +'use client' + +import { useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import useSWR from 'swr' +import Datasets from './Datasets' +import DatasetFooter from './DatasetFooter' +import ApiServer from './ApiServer' +import Doc from './Doc' +import TabSlider from '@/app/components/base/tab-slider' +import { fetchDatasetApiBaseUrl } from '@/service/datasets' + +const Container = () => { + const { t } = useTranslation() + const options = [ + { + value: 'dataset', + text: t('dataset.datasets'), + }, + { + value: 'api', + text: t('dataset.datasetsApi'), + }, + ] + const [activeTab, setActiveTab] = useState('dataset') + const containerRef = useRef(null) + const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl) + + return ( +
+
+ setActiveTab(newActiveTab)} + options={options} + /> + { + activeTab === 'api' && ( + + ) + } +
+ { + activeTab === 'dataset' && ( +
+ + +
+ ) + } + { + activeTab === 'api' && ( + + ) + } +
+ ) +} + +export default Container diff --git a/web/app/(commonLayout)/datasets/Datasets.tsx b/web/app/(commonLayout)/datasets/Datasets.tsx index 3d3ecf91b7cd02..0bec39fe109b21 100644 --- a/web/app/(commonLayout)/datasets/Datasets.tsx +++ b/web/app/(commonLayout)/datasets/Datasets.tsx @@ -7,7 +7,7 @@ import NewDatasetCard from './NewDatasetCard' import DatasetCard from './DatasetCard' import type { DataSetListResponse } from '@/models/datasets' import { fetchDatasets } from '@/service/datasets' -import { useAppContext, useSelector } from '@/context/app-context' +import { useAppContext } from '@/context/app-context' const getKey = (pageIndex: number, previousPageData: DataSetListResponse) => { if (!pageIndex || previousPageData.has_more) @@ -15,11 +15,16 @@ const getKey = (pageIndex: number, previousPageData: DataSetListResponse) => { return null } -const Datasets = () => { +type Props = { + containerRef: React.RefObject +} + +const Datasets = ({ + containerRef, +}: Props) => { const { isCurrentWorkspaceManager } = useAppContext() const { data, isLoading, setSize, mutate } = useSWRInfinite(getKey, fetchDatasets, { revalidateFirstPage: false, revalidateAll: true }) const loadingStateRef = useRef(false) - const pageContainerRef = useSelector(state => state.pageContainerRef) const anchorRef = useRef(null) useEffect(() => { @@ -29,19 +34,19 @@ const Datasets = () => { useEffect(() => { const onScroll = debounce(() => { if (!loadingStateRef.current) { - const { scrollTop, clientHeight } = pageContainerRef.current! + const { scrollTop, clientHeight } = containerRef.current! const anchorOffset = anchorRef.current!.offsetTop if (anchorOffset - scrollTop - clientHeight < 100) setSize(size => size + 1) } }, 50) - pageContainerRef.current?.addEventListener('scroll', onScroll) - return () => pageContainerRef.current?.removeEventListener('scroll', onScroll) + containerRef.current?.addEventListener('scroll', onScroll) + return () => containerRef.current?.removeEventListener('scroll', onScroll) }, []) return ( -