From 0deb19e800f7751fcbc8f566cac6c36dabc1a3ee Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 10 Dec 2024 09:06:34 +0800 Subject: [PATCH] fix(app_generator_service): overload type hints (#11507) Signed-off-by: -LAN- --- .../app/apps/advanced_chat/app_generator.py | 37 ++++++++++++++- api/core/app/apps/agent_chat/app_generator.py | 37 ++++++++++++++- api/core/app/apps/chat/app_generator.py | 26 +++++++---- api/core/app/apps/completion/app_generator.py | 31 +++++++++---- api/core/app/apps/workflow/app_generator.py | 46 ++++++++++++++++++- 5 files changed, 155 insertions(+), 22 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index bd4fd9cd3b2646..6200299d21c869 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -36,6 +36,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): _dialogue_count: int + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload def generate( self, app_model: App, @@ -44,7 +67,17 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... + + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ Generate App response. diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b659c1855624b5..b391169e3dbe5c 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -2,7 +2,7 @@ import threading import uuid from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -28,6 +28,39 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + def generate( self, *, @@ -36,7 +69,7 @@ def generate( args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ): """ Generate App response. diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 6a9e1623881e17..5b8debaaae6a56 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,7 +1,7 @@ import logging import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app @@ -34,9 +34,9 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -44,19 +44,29 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... def generate( self, app_model: App, user: Union[Account, EndUser], - args: Any, + args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ): """ Generate App response. diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 324e837a1c8f29..14fd33dd398927 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,7 +1,7 @@ import logging import threading import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app @@ -34,9 +34,9 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[True] = True, + streaming: Literal[True], ) -> Generator[str, None, None]: ... @overload @@ -44,14 +44,29 @@ def generate( self, app_model: App, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... + streaming: Literal[False], + ) -> Mapping[str, Any]: ... + @overload def generate( - self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True - ) -> Union[dict, Generator[str, None, None]]: + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ Generate App response. diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 7acf05326efe3f..dc4ee9e566a2f3 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -30,6 +30,35 @@ class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Mapping[str, Any]: ... + + @overload def generate( self, *, @@ -41,7 +70,20 @@ def generate( streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ) -> Mapping[str, Any] | Generator[str, None, None]: + ) -> Mapping[str, Any] | Generator[str, None, None]: ... + + def generate( + self, + *, + app_model: App, + workflow: Workflow, + user: Account | EndUser, + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files