Skip to content

Commit

Permalink
Merge branch 'feat/generate-conv-name-optional' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Oct 20, 2023
2 parents 0d6853b + 30d05d2 commit 7d22074
Show file tree
Hide file tree
Showing 17 changed files with 129 additions and 38 deletions.
2 changes: 2 additions & 0 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def post(self, app_id):
args = parser.parse_args()

streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False

account = flask_login.current_user

Expand Down Expand Up @@ -120,6 +121,7 @@ def post(self, app_id):
args = parser.parse_args()

streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False

account = flask_login.current_user

Expand Down
10 changes: 10 additions & 0 deletions api/controllers/console/explore/completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
from typing import Generator, Union

from flask import Response, stream_with_context
Expand All @@ -17,6 +18,7 @@
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService

Expand All @@ -37,6 +39,10 @@ def post(self, installed_app):
args = parser.parse_args()

streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False

installed_app.last_used_at = datetime.utcnow()
db.session.commit()

try:
response = CompletionService.completion(
Expand Down Expand Up @@ -97,6 +103,10 @@ def post(self, installed_app):
args = parser.parse_args()

streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False

installed_app.last_used_at = datetime.utcnow()
db.session.commit()

try:
response = CompletionService.completion(
Expand Down
12 changes: 10 additions & 2 deletions api/controllers/console/explore/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def get(self, installed_app):
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
pinned=pinned,
exclude_debug_conversation=True
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
Expand Down Expand Up @@ -72,10 +73,17 @@ def post(self, installed_app, c_id):

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()

try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

Expand Down
5 changes: 3 additions & 2 deletions api/controllers/console/explore/installed_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def get(self):
}
for installed_app in installed_apps
]
installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
if app['last_used_at'] is not None else datetime.min))
installed_apps.sort(key=lambda app: (-app['is_pinned'],
app['last_used_at'] is None,
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))

return {'installed_apps': installed_apps}

Expand Down
2 changes: 2 additions & 0 deletions api/controllers/console/universal_chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def post(self, universal_app):
del args['model']
del args['tools']

args['auto_generate_name'] = False

try:
response = CompletionService.completion(
app_model=app_model,
Expand Down
11 changes: 9 additions & 2 deletions api/controllers/console/universal_chat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,18 @@ def post(self, universal_app, c_id):
conversation_id = str(c_id)

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()

try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

Expand Down
5 changes: 4 additions & 1 deletion api/controllers/service_api/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def post(self, app_model, end_user):
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

args['auto_generate_name'] = False

try:
response = CompletionService.completion(
app_model=app_model,
user=end_user,
args=args,
from_source='api',
streaming=streaming
streaming=streaming,
)

return compact_response(response)
Expand Down Expand Up @@ -94,6 +96,7 @@ def post(self, app_model, end_user):
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')

args = parser.parse_args()

Expand Down
11 changes: 9 additions & 2 deletions api/controllers/service_api/app/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,22 @@ def post(self, app_model, end_user, c_id):
conversation_id = str(c_id)

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
end_user,
args['name'],
args['auto_generate']
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

Expand Down
2 changes: 2 additions & 0 deletions api/controllers/web/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def post(self, app_model, end_user):
args = parser.parse_args()

streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False

try:
response = CompletionService.completion(
Expand Down Expand Up @@ -95,6 +96,7 @@ def post(self, app_model, end_user):
args = parser.parse_args()

streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False

try:
response = CompletionService.completion(
Expand Down
11 changes: 9 additions & 2 deletions api/controllers/web/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ def post(self, app_model, end_user, c_id):
conversation_id = str(c_id)

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()

try:
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
end_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

Expand Down
6 changes: 4 additions & 2 deletions api/core/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
is_override: bool = False, retriever_from: str = 'dev'):
is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
"""
errors: ProviderTokenNotInitError
"""
Expand Down Expand Up @@ -58,7 +59,8 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer
inputs=inputs,
query=query,
streaming=streaming,
model_instance=final_model_instance
model_instance=final_model_instance,
auto_generate_name=auto_generate_name
)

rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
Expand Down
9 changes: 6 additions & 3 deletions api/core/conversation_message_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
conversation: Optional[Conversation] = None, is_override: bool = False,
auto_generate_name: bool = True):
self.start_at = time.perf_counter()

self.task_id = task_id
Expand All @@ -45,6 +46,7 @@ def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, use
self.message = None

self.retriever_resource = None
self.auto_generate_name = auto_generate_name

self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
Expand Down Expand Up @@ -100,7 +102,7 @@ def init(self):
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='',
name='New conversation',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
Expand Down Expand Up @@ -176,7 +178,8 @@ def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation
is_first_message=self.is_new_conversation,
auto_generate_name=self.auto_generate_name
)

if not by_stopped:
Expand Down
8 changes: 6 additions & 2 deletions api/core/generator/llm_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
def generate_conversation_name(cls, tenant_id: str, query):
prompt = CONVERSATION_TITLE_PROMPT

if len(query) > 2000:
Expand All @@ -40,8 +40,12 @@ def generate_conversation_name(cls, tenant_id: str, query, answer):

result_dict = json.loads(answer)
answer = result_dict['Your Output']
name = answer.strip()

return answer.strip()
if len(name) > 75:
name = name[:75] + '...'

return name

@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

from core.generator.llm_generator import LLMGenerator
from events.message_event import message_was_created
from extensions.ext_database import db
Expand All @@ -10,23 +8,19 @@ def handle(sender, **kwargs):
message = sender
conversation = kwargs.get('conversation')
is_first_message = kwargs.get('is_first_message')
auto_generate_name = kwargs.get('auto_generate_name', True)

if is_first_message:
if auto_generate_name and is_first_message:
if conversation.mode == 'chat':
app_model = conversation.app
if not app_model:
return

# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)

if len(name) > 75:
name = name[:75] + '...'

name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
conversation.name = name
except:
conversation.name = 'New conversation'
pass

db.session.add(conversation)
db.session.commit()
13 changes: 9 additions & 4 deletions api/services/completion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
# is streaming mode
inputs = args['inputs']
query = args['query']
auto_generate_name = args['auto_generate_name'] \
if 'auto_generate_name' in args else True

if app_model.mode != 'completion' and not query:
raise ValueError('query is required')
Expand Down Expand Up @@ -149,7 +151,8 @@ def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
'detached_conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
})

generate_worker_thread.start()
Expand All @@ -174,7 +177,7 @@ def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
retriever_from: str = 'dev', auto_generate_name: bool = True):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
Expand All @@ -197,7 +200,8 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
)
except ConversationTaskStoppedException:
pass
Expand Down Expand Up @@ -291,7 +295,8 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
'retriever_from': retriever_from,
'auto_generate_name': False
})

generate_worker_thread.start()
Expand Down
Loading

0 comments on commit 7d22074

Please sign in to comment.