Skip to content

Commit

Permalink
fix(app_generator): correct overload
Browse files Browse the repository at this point in the history
Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 committed Dec 9, 2024
1 parent fbb85c5 commit 9a76007
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 11 deletions.
37 changes: 35 additions & 2 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
Expand Down Expand Up @@ -36,6 +36,29 @@
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int

@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...

@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
app_model: App,
Expand All @@ -44,7 +67,17 @@ def generate(
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...

def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.
Expand Down
37 changes: 35 additions & 2 deletions api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Literal, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
Expand All @@ -28,6 +28,39 @@


class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...

@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...

def generate(
self,
*,
Expand All @@ -36,7 +69,7 @@ def generate(
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
):
"""
Generate App response.
Expand Down
4 changes: 2 additions & 2 deletions api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate(
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...

@overload
Expand All @@ -46,7 +46,7 @@ def generate(
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False] = False,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
Expand Down
6 changes: 3 additions & 3 deletions api/core/app/apps/completion/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate(
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...

@overload
Expand All @@ -46,7 +46,7 @@ def generate(
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False] = False,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
Expand All @@ -57,7 +57,7 @@ def generate(
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
) -> Mapping[str, Any] | Generator[str, None, None]: ...

def generate(
self,
Expand Down
46 changes: 44 additions & 2 deletions api/core/app/apps/workflow/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
Expand All @@ -30,6 +30,35 @@


class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...

@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
*,
Expand All @@ -41,7 +70,20 @@ def generate(
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str, None, None]: ...

def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []

# parse files
Expand Down

0 comments on commit 9a76007

Please sign in to comment.