From 30d05d223d1974a6c6653b4d2f84f2d2e8a6e6e5 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 20 Oct 2023 17:54:01 +0800 Subject: [PATCH] feat: conversation name generation optional fix: explore app list sort rule feat: disable debug conversations in explore apps --- api/controllers/console/app/completion.py | 2 + api/controllers/console/explore/completion.py | 10 +++++ .../console/explore/conversation.py | 12 +++++- .../console/explore/installed_app.py | 5 ++- .../console/universal_chat/chat.py | 2 + .../console/universal_chat/conversation.py | 11 ++++- api/controllers/service_api/app/completion.py | 5 ++- .../service_api/app/conversation.py | 11 ++++- api/controllers/web/completion.py | 2 + api/controllers/web/conversation.py | 11 ++++- api/core/completion.py | 6 ++- api/core/conversation_message_task.py | 9 +++-- api/core/generator/llm_generator.py | 8 +++- ...rsation_name_when_first_message_created.py | 14 ++----- api/services/completion_service.py | 13 ++++-- api/services/conversation_service.py | 40 +++++++++++++++++-- api/services/web_conversation_service.py | 6 ++- 17 files changed, 129 insertions(+), 38 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 1da7bd8f2ce974..fd55de55cdd658 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -46,6 +46,7 @@ def post(self, app_id): args = parser.parse_args() streaming = args['response_mode'] != 'blocking' + args['auto_generate_name'] = False account = flask_login.current_user @@ -120,6 +121,7 @@ def post(self, app_id): args = parser.parse_args() streaming = args['response_mode'] != 'blocking' + args['auto_generate_name'] = False account = flask_login.current_user diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index bdf1f3b907dfdd..96820c0b0e2069 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,7 @@ # -*- coding:utf-8 -*- import json import logging +from datetime import datetime from typing import Generator, Union from flask import Response, stream_with_context @@ -17,6 +18,7 @@ from core.conversation_message_task import PubHandler from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from extensions.ext_database import db from libs.helper import uuid_value from services.completion_service import CompletionService @@ -37,6 +39,10 @@ def post(self, installed_app): args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False + + installed_app.last_used_at = datetime.utcnow() + db.session.commit() try: response = CompletionService.completion( @@ -97,6 +103,10 @@ def post(self, installed_app): args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False + + installed_app.last_used_at = datetime.utcnow() + db.session.commit() try: response = CompletionService.completion( diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index aa56437f33166b..be1c875d290c5e 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -38,7 +38,8 @@ def get(self, installed_app): user=current_user, last_id=args['last_id'], limit=args['limit'], - pinned=pinned + pinned=pinned, + exclude_debug_conversation=True ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -72,10 +73,17 @@ def post(self, installed_app, c_id): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, current_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + current_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index d36d02828ee5f9..d7ee991663ed8b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -39,8 +39,9 @@ def get(self): } for installed_app in installed_apps ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at'] - if app['last_used_at'] is not None else datetime.min)) + installed_apps.sort(key=lambda app: (-app['is_pinned'], + app['last_used_at'] is None, + -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) return {'installed_apps': installed_apps} diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py index 61ba50325eb9ba..5e7ad826e39882 100644 --- a/api/controllers/console/universal_chat/chat.py +++ b/api/controllers/console/universal_chat/chat.py @@ -60,6 +60,8 @@ def post(self, universal_app): del args['model'] del args['tools'] + args['auto_generate_name'] = False + try: response = CompletionService.completion( app_model=app_model, diff --git a/api/controllers/console/universal_chat/conversation.py b/api/controllers/console/universal_chat/conversation.py index c0782cb81a91ac..a85e392c25f837 100644 --- a/api/controllers/console/universal_chat/conversation.py +++ b/api/controllers/console/universal_chat/conversation.py @@ -65,11 +65,18 @@ def post(self, universal_app, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, current_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + current_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index a339322ea89a8e..8cabe376610c7c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -39,13 +39,15 @@ def post(self, app_model, end_user): if end_user is None and args['user'] is not None: end_user = create_or_update_end_user_for_user_id(app_model, args['user']) + args['auto_generate_name'] = False + try: response = CompletionService.completion( app_model=app_model, user=end_user, args=args, from_source='api', - streaming=streaming + streaming=streaming, ) return compact_response(response) @@ -94,6 +96,7 @@ def post(self, app_model, end_user): parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('user', type=str, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json') args = parser.parse_args() diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index e111ea2ebc51ee..2fdddef8a10081 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -65,15 +65,22 @@ def post(self, app_model, end_user, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') parser.add_argument('user', type=str, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json') args = parser.parse_args() if end_user is None and args['user'] is not None: end_user = create_or_update_end_user_for_user_id(app_model, args['user']) try: - return ConversationService.rename(app_model, conversation_id, end_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + end_user, + args['name'], + args['auto_generate'] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 79c0c542d17852..eeeec36d14b52f 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -36,6 +36,7 @@ def post(self, app_model, end_user): args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False try: response = CompletionService.completion( @@ -95,6 +96,7 @@ def post(self, app_model, end_user): args = parser.parse_args() streaming = args['response_mode'] == 'streaming' + args['auto_generate_name'] = False try: response = CompletionService.completion( diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index ce089ca39572f5..f6bb96bf18e5b6 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -67,11 +67,18 @@ def post(self, app_model, end_user, c_id): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('name', type=str, required=False, location='json') + parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json') args = parser.parse_args() try: - return ConversationService.rename(app_model, conversation_id, end_user, args['name']) + return ConversationService.rename( + app_model, + conversation_id, + end_user, + args['name'], + args['auto_generate'] + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/core/completion.py b/api/core/completion.py index 57e18199271ccb..4e38af6b0d8f0f 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -24,7 +24,8 @@ class Completion: @classmethod def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, - is_override: bool = False, retriever_from: str = 'dev'): + is_override: bool = False, retriever_from: str = 'dev', + auto_generate_name: bool = True): """ errors: ProviderTokenNotInitError """ @@ -58,7 +59,8 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer inputs=inputs, query=query, streaming=streaming, - model_instance=final_model_instance + model_instance=final_model_instance, + auto_generate_name=auto_generate_name ) rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 3be6ffaee37bb0..7c945fba6809c2 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -22,7 +22,8 @@ class ConversationMessageTask: def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, inputs: dict, query: str, streaming: bool, model_instance: BaseLLM, - conversation: Optional[Conversation] = None, is_override: bool = False): + conversation: Optional[Conversation] = None, is_override: bool = False, + auto_generate_name: bool = True): self.start_at = time.perf_counter() self.task_id = task_id @@ -45,6 +46,7 @@ def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, use self.message = None self.retriever_resource = None + self.auto_generate_name = auto_generate_name self.model_dict = self.app_model_config.model_dict self.provider_name = self.model_dict.get('provider') @@ -100,7 +102,7 @@ def init(self): model_id=self.model_name, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=self.mode, - name='', + name='New conversation', inputs=self.inputs, introduction=introduction, system_instruction=system_instruction, @@ -176,7 +178,8 @@ def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): message_was_created.send( self.message, conversation=self.conversation, - is_first_message=self.is_new_conversation + is_first_message=self.is_new_conversation, + auto_generate_name=self.auto_generate_name ) if not by_stopped: diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index a6699f32d7a127..87a934e55d54e6 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -16,7 +16,7 @@ class LLMGenerator: @classmethod - def generate_conversation_name(cls, tenant_id: str, query, answer): + def generate_conversation_name(cls, tenant_id: str, query): prompt = CONVERSATION_TITLE_PROMPT if len(query) > 2000: @@ -40,8 +40,12 @@ def generate_conversation_name(cls, tenant_id: str, query, answer): result_dict = json.loads(answer) answer = result_dict['Your Output'] + name = answer.strip() - return answer.strip() + if len(name) > 75: + name = name[:75] + '...' + + return name @classmethod def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 4176f9dbbb7681..b35e67969b6dc6 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,5 +1,3 @@ -import logging - from core.generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db @@ -10,8 +8,9 @@ def handle(sender, **kwargs): message = sender conversation = kwargs.get('conversation') is_first_message = kwargs.get('is_first_message') + auto_generate_name = kwargs.get('auto_generate_name', True) - if is_first_message: + if auto_generate_name and is_first_message: if conversation.mode == 'chat': app_model = conversation.app if not app_model: @@ -19,14 +18,9 @@ def handle(sender, **kwargs): # generate conversation name try: - name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer) - - if len(name) > 75: - name = name[:75] + '...' - + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) conversation.name = name except: - conversation.name = 'New conversation' + pass - db.session.add(conversation) db.session.commit() diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 54b150d155fb82..82303390a7b656 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -34,6 +34,8 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, # is streaming mode inputs = args['inputs'] query = args['query'] + auto_generate_name = args['auto_generate_name'] \ + if 'auto_generate_name' in args else True if app_model.mode != 'completion' and not query: raise ValueError('query is required') @@ -149,7 +151,8 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, 'detached_conversation': conversation, 'streaming': streaming, 'is_model_config_override': is_model_config_override, - 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev' + 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev', + 'auto_generate_name': auto_generate_name }) generate_worker_thread.start() @@ -174,7 +177,7 @@ def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]): def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig, query: str, inputs: dict, detached_user: Union[Account, EndUser], detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool, - retriever_from: str = 'dev'): + retriever_from: str = 'dev', auto_generate_name: bool = True): with flask_app.app_context(): # fixed the state of the model object when it detached from the original session user = db.session.merge(detached_user) @@ -197,7 +200,8 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m conversation=conversation, streaming=streaming, is_override=is_model_config_override, - retriever_from=retriever_from + retriever_from=retriever_from, + auto_generate_name=auto_generate_name ) except ConversationTaskStoppedException: pass @@ -291,7 +295,8 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], 'detached_conversation': None, 'streaming': streaming, 'is_model_config_override': True, - 'retriever_from': retriever_from + 'retriever_from': retriever_from, + 'auto_generate_name': False }) generate_worker_thread.start() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 45f18ec4cfab7b..0872a232f09eaf 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,17 +1,20 @@ from typing import Union, Optional +from core.generator.llm_generator import LLMGenerator from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db from models.account import Account -from models.model import Conversation, App, EndUser +from models.model import Conversation, App, EndUser, Message from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError +from services.errors.message import MessageNotExistsError class ConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], last_id: Optional[str], limit: int, - include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + include_ids: Optional[list] = None, exclude_ids: Optional[list] = None, + exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -29,6 +32,9 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | En if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + if exclude_debug_conversation: + base_query = base_query.filter(Conversation.override_model_configs == None) + if last_id: last_conversation = base_query.filter( Conversation.id == last_id, @@ -63,10 +69,36 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | En @classmethod def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account | EndUser]], name: str): + user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool): conversation = cls.get_conversation(app_model, conversation_id, user) - conversation.name = name + if auto_generate: + return cls.auto_generate_name(app_model, conversation) + else: + conversation.name = name + db.session.commit() + + return conversation + + @classmethod + def auto_generate_name(cls, app_model: App, conversation: Conversation): + # get conversation first message + message = db.session.query(Message) \ + .filter( + Message.app_id == app_model.id, + Message.conversation_id == conversation.id + ).order_by(Message.created_at.asc()).first() + + if not message: + raise MessageNotExistsError() + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) + conversation.name = name + except: + pass + db.session.commit() return conversation diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 231083db192528..3c521909dca177 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,7 +11,8 @@ class WebConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], - last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination: + last_id: Optional[str], limit: int, pinned: Optional[bool] = None, + exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: @@ -32,7 +33,8 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | En last_id=last_id, limit=limit, include_ids=include_ids, - exclude_ids=exclude_ids + exclude_ids=exclude_ids, + exclude_debug_conversation=exclude_debug_conversation ) @classmethod