diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index bd4fd9cd3b2646..6200299d21c869 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -36,6 +36,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): _dialogue_count: int + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload def generate( self, app_model: App, @@ -44,7 +67,17 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... + + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ Generate App response. diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b659c1855624b5..b391169e3dbe5c 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -2,7 +2,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,6 +28,39 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + def generate( self, *, @@ -36,7 +69,7 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ): """ Generate App response. diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index fdaf8a6607d3b5..5b8debaaae6a56 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -36,7 +36,7 @@ def generate( user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - streaming: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -46,7 +46,7 @@ def generate( user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - streaming: Literal[False] = False, + streaming: Literal[False], ) -> Mapping[str, Any]: ... @overload diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 7eccc06b44614e..14fd33dd398927 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -36,7 +36,7 @@ def generate( user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - streaming: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -46,7 +46,7 @@ def generate( user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - streaming: Literal[False] = False, + streaming: Literal[False], ) -> Mapping[str, Any]: ... @overload @@ -57,7 +57,7 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool, - ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... + ) -> Mapping[str, Any] | Generator[str, None, None]: ... def generate( self, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 7acf05326efe3f..dc4ee9e566a2f3 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -30,6 +30,35 @@ class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Mapping[str, Any]: ... + + @overload def generate( self, *, @@ -41,7 +70,20 @@ def generate( streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files