Skip to content

Commit

Permalink
Merge branch 'feat/workflow-backend' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Mar 16, 2024
2 parents cfb80ca + 11dfdb2 commit 41ca463
Show file tree
Hide file tree
Showing 21 changed files with 663 additions and 70 deletions.
7 changes: 5 additions & 2 deletions api/controllers/console/explore/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from extensions.ext_database import db
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService


Expand Down Expand Up @@ -95,7 +96,8 @@ def post(self, installed_app, task_id):
class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -148,7 +150,8 @@ def post(self, installed_app):
class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
Expand Down
16 changes: 11 additions & 5 deletions api/controllers/console/explore/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
Expand All @@ -18,7 +19,8 @@ class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -47,7 +49,8 @@ def get(self, installed_app):
class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -65,7 +68,8 @@ class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -91,7 +95,8 @@ class ConversationPinApi(InstalledAppResource):

def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -107,7 +112,8 @@ def patch(self, installed_app, c_id):
class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/explore/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class NotCompletionAppError(BaseHTTPException):

class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Not Chat App"
description = "App mode is invalid."
code = 400


Expand Down
9 changes: 6 additions & 3 deletions api/controllers/console/explore/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
Expand All @@ -38,7 +39,8 @@ class MessageListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app

if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -118,8 +120,9 @@ def get(self, installed_app, message_id):
class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
if app_model.mode != 'chat':
raise NotCompletionAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

message_id = str(message_id)

Expand Down
8 changes: 5 additions & 3 deletions api/controllers/service_api/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService


Expand Down Expand Up @@ -90,7 +90,8 @@ def post(self, app_model: App, end_user: EndUser, task_id):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -141,7 +142,8 @@ def post(self, app_model: App, end_user: EndUser):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
Expand Down
12 changes: 8 additions & 4 deletions api/controllers/service_api/app/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService


Expand All @@ -17,7 +17,8 @@ class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand All @@ -30,11 +31,13 @@ def get(self, app_model: App, end_user: EndUser):
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")


class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -51,7 +54,8 @@ class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/service_api/app/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException):

class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Please check if your Chat app mode matches the right API route."
description = "Please check if your app mode matches the right API route."
code = 400


Expand Down
8 changes: 5 additions & 3 deletions api/controllers/service_api/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.message_service import MessageService


Expand Down Expand Up @@ -71,7 +71,8 @@ class MessageListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -110,7 +111,8 @@ class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

try:
Expand Down
7 changes: 5 additions & 2 deletions api/controllers/web/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService


Expand Down Expand Up @@ -88,7 +89,8 @@ def post(self, app_model, end_user, task_id):

class ChatApi(WebApiResource):
def post(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -138,7 +140,8 @@ def post(self, app_model, end_user):

class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
Expand Down
16 changes: 11 additions & 5 deletions api/controllers/web/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
Expand All @@ -16,7 +17,8 @@ class ConversationListApi(WebApiResource):

@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand All @@ -43,7 +45,8 @@ def get(self, app_model, end_user):

class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -60,7 +63,8 @@ class ConversationRenameApi(WebApiResource):

@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -85,7 +89,8 @@ def post(self, app_model, end_user, c_id):
class ConversationPinApi(WebApiResource):

def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand All @@ -100,7 +105,8 @@ def patch(self, app_model, end_user, c_id):

class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

conversation_id = str(c_id)
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/web/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class NotCompletionAppError(BaseHTTPException):

class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
description = "Please check if your Chat app mode matches the right API route."
description = "Please check if your app mode matches the right API route."
code = 400


Expand Down
7 changes: 5 additions & 2 deletions api/controllers/web/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from fields.message_fields import agent_thought_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
Expand Down Expand Up @@ -76,7 +77,8 @@ class MessageListApi(WebApiResource):

@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()

parser = reqparse.RequestParser()
Expand Down Expand Up @@ -154,7 +156,8 @@ def get(self, app_model, end_user, message_id):

class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
if app_model.mode != 'chat':
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotCompletionAppError()

message_id = str(message_id)
Expand Down
6 changes: 4 additions & 2 deletions api/core/workflow/nodes/knowledge_retrieval/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pydantic import BaseModel

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class RerankingModelConfig(BaseModel):
Expand Down Expand Up @@ -44,7 +43,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
"""
variables: list[VariableSelector]
title: str
desc: str
type: str = 'knowledge-retrieval'
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal['single', 'multiple']
multiple_retrieval_config: MultipleRetrievalConfig
Expand Down
Loading

0 comments on commit 41ca463

Please sign in to comment.