From 493834d45dd117f8324dd746fd28a4d1b46429ce Mon Sep 17 00:00:00 2001 From: yihong Date: Wed, 18 Dec 2024 17:36:36 +0800 Subject: [PATCH 01/91] ci: add config ci more disscuss check #11706 (#11752) Signed-off-by: yihong0618 --- .github/workflows/api-tests.yml | 3 + dev/pytest/pytest_config_tests.py | 111 ++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 dev/pytest/pytest_config_tests.py diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index e1c0bf33a4ff31..2cd0b2a7d430de 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -50,6 +50,9 @@ jobs: - name: Run ModelRuntime run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh + - name: Run dify config tests + run: poetry run -C api python dev/pytest/pytest_config_tests.py + - name: Run Tool run: poetry run -C api bash dev/pytest/pytest_tools.sh diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py new file mode 100644 index 00000000000000..08adc9ebe999b1 --- /dev/null +++ b/dev/pytest/pytest_config_tests.py @@ -0,0 +1,111 @@ +import yaml # type: ignore +from dotenv import dotenv_values +from pathlib import Path + +BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { + "APP_MAX_EXECUTION_TIME", + "BATCH_UPLOAD_LIMIT", + "CELERY_BEAT_SCHEDULER_TIME", + "CODE_EXECUTION_API_KEY", + "HTTP_REQUEST_MAX_CONNECT_TIMEOUT", + "HTTP_REQUEST_MAX_READ_TIMEOUT", + "HTTP_REQUEST_MAX_WRITE_TIMEOUT", + "KEYWORD_DATA_SOURCE_TYPE", + "LOGIN_LOCKOUT_DURATION", + "LOG_FORMAT", + "OCI_ACCESS_KEY", + "OCI_BUCKET_NAME", + "OCI_ENDPOINT", + "OCI_REGION", + "OCI_SECRET_KEY", + "REDIS_DB", + "RESEND_API_URL", + "RESPECT_XFORWARD_HEADERS_ENABLED", + "SENTRY_DSN", + "SSRF_DEFAULT_CONNECT_TIME_OUT", + "SSRF_DEFAULT_MAX_RETRIES", + "SSRF_DEFAULT_READ_TIME_OUT", + "SSRF_DEFAULT_TIME_OUT", + "SSRF_DEFAULT_WRITE_TIME_OUT", + "UPSTASH_VECTOR_TOKEN", + "UPSTASH_VECTOR_URL", + "USING_UGC_INDEX", + "WEAVIATE_BATCH_SIZE", + "WEAVIATE_GRPC_ENABLED", +} + +BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { + "BATCH_UPLOAD_LIMIT", + "CELERY_BEAT_SCHEDULER_TIME", + "HTTP_REQUEST_MAX_CONNECT_TIMEOUT", + "HTTP_REQUEST_MAX_READ_TIMEOUT", + "HTTP_REQUEST_MAX_WRITE_TIMEOUT", + "KEYWORD_DATA_SOURCE_TYPE", + "LOGIN_LOCKOUT_DURATION", + "LOG_FORMAT", + "OPENDAL_FS_ROOT", + "OPENDAL_S3_ACCESS_KEY_ID", + "OPENDAL_S3_BUCKET", + "OPENDAL_S3_ENDPOINT", + "OPENDAL_S3_REGION", + "OPENDAL_S3_ROOT", + "OPENDAL_S3_SECRET_ACCESS_KEY", + "OPENDAL_S3_SERVER_SIDE_ENCRYPTION", + "PGVECTOR_MAX_CONNECTION", + "PGVECTOR_MIN_CONNECTION", + "PGVECTO_RS_DATABASE", + "PGVECTO_RS_HOST", + "PGVECTO_RS_PASSWORD", + "PGVECTO_RS_PORT", + "PGVECTO_RS_USER", + "RESPECT_XFORWARD_HEADERS_ENABLED", + "SCARF_NO_ANALYTICS", + "SSRF_DEFAULT_CONNECT_TIME_OUT", + "SSRF_DEFAULT_MAX_RETRIES", + "SSRF_DEFAULT_READ_TIME_OUT", + "SSRF_DEFAULT_TIME_OUT", + "SSRF_DEFAULT_WRITE_TIME_OUT", + "STORAGE_OPENDAL_SCHEME", + "SUPABASE_API_KEY", + "SUPABASE_BUCKET_NAME", + "SUPABASE_URL", + "USING_UGC_INDEX", + "VIKINGDB_CONNECTION_TIMEOUT", + "VIKINGDB_SOCKET_TIMEOUT", + "WEAVIATE_BATCH_SIZE", + "WEAVIATE_GRPC_ENABLED", +} + +API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys()) +DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys()) +DOCKER_COMPOSE_CONFIG_SET = set() + +with open(Path("docker") / Path("docker-compose.yaml")) as f: + DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys()) + + +def test_yaml_config(): + # python set == operator is 
used to compare two sets; the set differences below collect the keys missing downstream
+    DIFF_API_WITH_DOCKER = (
+        API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
+    )
+    if DIFF_API_WITH_DOCKER:
+        print(
+            f"API and Docker config sets are different with keys: {DIFF_API_WITH_DOCKER}"
+        )
+        raise Exception("API and Docker config sets are different")
+    DIFF_API_WITH_DOCKER_COMPOSE = (
+        API_CONFIG_SET
+        - DOCKER_COMPOSE_CONFIG_SET
+        - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
+    )
+    if DIFF_API_WITH_DOCKER_COMPOSE:
+        print(
+            f"API and Docker Compose config sets are different with keys: {DIFF_API_WITH_DOCKER_COMPOSE}"
+        )
+        raise Exception("API and Docker Compose config sets are different")
+    print("All tests passed!")
+
+
+if __name__ == "__main__":
+    test_yaml_config()

From b5c2785e108e40eb43775e10e88714ab23705128 Mon Sep 17 00:00:00 2001
From: yihong
Date: Wed, 18 Dec 2024 20:17:10 +0800
Subject: [PATCH 02/91] ci: fix config ci and it works (#11807)

Signed-off-by: yihong0618
---
 docker/.env.example        | 2 ++
 docker/docker-compose.yaml | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/docker/.env.example b/docker/.env.example
index 0a5dffc57057f2..e5bddf7ae13dc8 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -107,6 +107,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
 
 # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
 APP_MAX_ACTIVE_REQUESTS=0
+APP_MAX_EXECUTION_TIME=1200
 
 # ------------------------------
 # Container Startup Related Configuration
@@ -606,6 +607,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 # Sentry Configuration
 # Used for application monitoring and error log tracking.
 # ------------------------------
+SENTRY_DSN=
 # API Service Sentry DSN address, default is empty, when empty,
 # all monitoring information is not reported to Sentry.
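The two additions in this patch pair up because of the consistency check introduced in PATCH 01: every key in api/.env.example must also appear in docker/.env.example and in the compose file's x-shared-env block unless it is explicitly allow-listed, and neither SENTRY_DSN nor APP_MAX_EXECUTION_TIME is in the compose allow-list. A minimal standalone sketch of the failing comparison, with hypothetical, reduced key sets rather than the real config:

# Reduced illustration of the PATCH 01 compose check; the real sets come from
# dotenv_values("api/.env.example") and the x-shared-env block of docker-compose.yaml.
api_config_set = {"DEBUG", "SENTRY_DSN", "APP_MAX_EXECUTION_TIME"}
docker_compose_config_set = {"DEBUG"}  # state before this patch
allowed_diff: set[str] = set()  # neither key is allow-listed for the compose check

missing = api_config_set - docker_compose_config_set - allowed_diff
if missing:  # -> {"SENTRY_DSN", "APP_MAX_EXECUTION_TIME"}, so CI raised until this patch
    raise Exception(f"API and Docker Compose config sets are different with keys: {missing}")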
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3c5e6d3a4d94f8..75f53f5ec3d022 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -18,6 +18,7 @@ x-shared-env: &shared-api-worker-env LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} + SENTRY_DSN: ${SENTRY_DSN:-} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} @@ -28,6 +29,7 @@ x-shared-env: &shared-api-worker-env FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} + APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_PORT: ${DIFY_PORT:-5001} SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} From 2624a6dcd0b89dd7c1aac2d7bfe7f769e9e3c992 Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Wed, 18 Dec 2024 21:24:21 +0800 Subject: [PATCH 03/91] Fix explore app icon (#11808) Co-authored-by: luowei Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/controllers/console/explore/recommended_app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index baf1f591b947e1..ce85f495aacd50 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -13,6 +13,7 @@ "name": fields.String, "mode": fields.String, "icon": fields.String, + "icon_type": fields.String, "icon_url": AppIconUrlField, "icon_background": fields.String, } From 3388d6636cb5aab1f1d18aed8a7be1df9b2e463a Mon Sep 17 00:00:00 2001 From: Agung Besti <35904444+agungbesti@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:36:11 +0700 Subject: [PATCH 04/91] add-model-azure-gpt-4o-2024-11-20 (#11803) Co-authored-by: agungbesti --- .../model_providers/azure_openai/_constant.py | 76 +++++++++++++++++++ .../azure_openai/azure_openai.yaml | 6 ++ 2 files changed, 82 insertions(+) diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 4cf58275d79fe3..3bd6375aa9117b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -819,6 +819,82 @@ class AzureBaseModel(BaseModel): ), ), ), + AzureBaseModel( + base_model_name="gpt-4o-2024-11-20", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject( + en_US="fake-deployment-name-label", + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name="top_p", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name="presence_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name="frequency_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], 
+ ), + _get_max_tokens(default=512, min_val=1, max_val=16384), + ParameterRule( + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", + help=I18nObject( + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", + ), + required=False, + ), + ], + pricing=PriceConfig( + input=5.00, + output=15.00, + unit=0.000001, + currency="USD", + ), + ), + ), AzureBaseModel( base_model_name="gpt-4-turbo", entity=AIModelEntity( diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 66c523504ef98f..a6ae47b28e5906 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -171,6 +171,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: gpt-4o-2024-11-20 + value: gpt-4o-2024-11-20 + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-turbo value: gpt-4-turbo From 560d375e0f3f1404cde54a012ad7602e9b8cca53 Mon Sep 17 00:00:00 2001 From: sino Date: Thu, 19 Dec 2024 17:49:31 +0800 Subject: [PATCH 05/91] feat(ark): add doubao-pro-256k and doubao-embedding-large (#11831) --- .../model_providers/volcengine_maas/llm/models.py | 4 ++++ .../volcengine_maas/text_embedding/models.py | 3 ++- .../volcengine_maas/volcengine_maas.yaml | 12 ++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index f7698f9443f360..cf3cf23cfb9cef 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -40,6 +40,10 @@ class ModelConfig(BaseModel): properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL], ), + "Doubao-pro-256k": ModelConfig( + properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT), + features=[], + ), "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL], diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 4a6f5b6f7bc7cd..be9bba5f2450d2 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -12,6 +12,7 @@ class ModelConfig(BaseModel): ModelConfigs = { "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), + "Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } @@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> 
ModelConfig: if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get("context_size", 0)), + context_size=int(credentials.get("context_size", 4096)), max_chunks=int(credentials.get("max_chunks", 1)), ) ) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index 57492b6d9f2fd7..2ddb612546690c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -166,6 +166,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: Doubao-pro-256k + value: Doubao-pro-256k + show_on: + - variable: __model_type + value: llm - label: en_US: Llama3-8B value: Llama3-8B @@ -220,6 +226,12 @@ model_credential_schema: show_on: - variable: __model_type value: text-embedding + - label: + en_US: Doubao-embedding-large + value: Doubao-embedding-large + show_on: + - variable: __model_type + value: text-embedding - label: en_US: Custom zh_Hans: 自定义 From d0570675436e314c335d1350010e074ef90dffb6 Mon Sep 17 00:00:00 2001 From: barabicu Date: Thu, 19 Dec 2024 19:30:51 +0900 Subject: [PATCH 06/91] fix: remove ruff ignore SIM300 (#11810) --- api/.ruff.toml | 1 - api/controllers/console/admin.py | 2 +- .../replicate/text_embedding/text_embedding.py | 2 +- .../model_runtime/__mock/xinference.py | 8 ++++---- .../integration_tests/tools/api_tool/test_api_tool.py | 10 +++++----- .../integration_tests/workflow/nodes/test_http.py | 2 +- 6 files changed, 12 insertions(+), 13 deletions(-) diff --git a/api/.ruff.toml b/api/.ruff.toml index 0f3185223c1596..26a1b977a9f6ac 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -70,7 +70,6 @@ ignore = [ "SIM113", # eumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false - "SIM300", # yoda-conditions, ] [lint.per-file-ignores] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a70c4a31c7db94..8c0bf8710d3964 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -31,7 +31,7 @@ def decorated(*args, **kwargs): if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - if dify_config.ADMIN_API_KEY != auth_token: + if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") return view(*args, **kwargs) diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index c4e9d0b9c6ceb2..41759fe07d0cac 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -119,7 +119,7 @@ def _generate_embeddings_by_text_input_key( embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif "texts" == text_input_key: + elif text_input_key == "texts": result = client.run( replicate_model_version, input={ diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 5f7dad50c10f11..794f4b0585632e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -21,13 +21,13 @@ def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHa if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): raise RuntimeError("404 Not Found") - if "generate" == model_uid: + if model_uid == "generate": return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "chat" == model_uid: + if model_uid == "chat": return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "embedding" == model_uid: + if model_uid == "embedding": return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if "rerank" == model_uid: + if model_uid == "rerank": return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) raise RuntimeError("404 Not Found") diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index 09729a961eff33..1bd75b91f745d6 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -34,9 +34,9 @@ def test_api_tool(setup_http_mock): response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert "/p_param" == response.request.url.path - assert b"query_param=q_param" == response.request.url.query - assert "h_param" == response.request.headers.get("header_param") - assert "application/json" == response.request.headers.get("content-type") - assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert response.request.url.path == "/p_param" + assert response.request.url.query == b"query_param=q_param" + assert response.request.headers.get("header_param") == "h_param" + assert response.request.headers.get("content-type") == "application/json" + assert response.request.headers.get("cookie") == "cookie_param=c_param" assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 9eea63f722e51f..0507fc707564dd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -384,7 +384,7 @@ def 
test_mock_404(setup_http_mock):
     assert result.outputs is not None
 
     resp = result.outputs
-    assert 404 == resp.get("status_code")
+    assert resp.get("status_code") == 404
     assert "Not Found" in resp.get("body", "")
 

From 12d45e9114cddf19767cbf5e267b5dd89f59792e Mon Sep 17 00:00:00 2001
From: yihong
Date: Thu, 19 Dec 2024 20:50:09 +0800
Subject: [PATCH 07/91] fix: siliconflow changed its model, fix #11844 (#11847)

Signed-off-by: yihong0618
---
 .../model_runtime/model_providers/siliconflow/siliconflow.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py
index e121ab8c7e4e2f..03c4306144a651 100644
--- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py
+++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py
@@ -18,7 +18,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
 
         try:
             model_instance = self.get_model_instance(ModelType.LLM)
-            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
+            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
         except CredentialsValidateFailedError as ex:
             raise ex
         except Exception as ex:

From 5a8a901560228a1556611ecaa67bcf5ddac22ce5 Mon Sep 17 00:00:00 2001
From: yihong
Date: Thu, 19 Dec 2024 20:50:20 +0800
Subject: [PATCH 08/91] fix: NaN float values are not JSON serializable, close #11827 (#11840)

Signed-off-by: yihong0618
---
 api/core/rag/embedding/cached_embedding.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index fc8e0440c332c3..652f7e145fd94d 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -65,6 +65,11 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
                 for vector in embedding_result.embeddings:
                     try:
                         normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                        # NaN check for a list, see: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
+                        if np.isnan(normalized_embedding).any():
+                            # for issue #11827: NaN float values are not JSON compliant
+                            logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
+                            continue
                         embedding_queue_embeddings.append(normalized_embedding)
                     except IntegrityError:
                         db.session.rollback()

From de3911e930e859b7e3fdf6f0c1ab181999c826dd Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Thu, 19 Dec 2024 21:19:08 +0800
Subject: [PATCH 09/91] Fix/10584 wrong message when no custom tool available
 in custom tool list (#11851)

---
 web/app/components/tools/add-tool-modal/empty.tsx | 12 +++++++++---
 web/i18n/en-US/tools.ts                           |  2 ++
 web/i18n/zh-Hans/tools.ts                         |  2 ++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/web/app/components/tools/add-tool-modal/empty.tsx b/web/app/components/tools/add-tool-modal/empty.tsx
index 051ae446d4cb0f..794d2a585c1379 100644
--- a/web/app/components/tools/add-tool-modal/empty.tsx
+++ b/web/app/components/tools/add-tool-modal/empty.tsx
@@ -1,13 +1,19 @@
+'use client'
+import { useSearchParams } from 'next/navigation'
 import { useTranslation } from 'react-i18next'
-
 const Empty = () => {
   const { t } = useTranslation()
+  const searchParams = useSearchParams()
   return (
     <div>
-      <div>{t('tools.addToolModal.emptyTitle')}</div>
-      <div>{t('tools.addToolModal.emptyTip')}</div>
+      <div>
+        {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTitle' : 'emptyTitleCustom'}`)}
+      </div>
+      <div>
+        {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTip' : 'emptyTipCustom'}`)}
+      </div>
     </div>
) } diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts index f96ae8144e0278..b1f278f9cea2a9 100644 --- a/web/i18n/en-US/tools.ts +++ b/web/i18n/en-US/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Manage in Tools', emptyTitle: 'No workflow tool available', emptyTip: 'Go to "Workflow -> Publish as Tool"', + emptyTitleCustom: 'No custom tool available', + emptyTipCustom: 'Create a custom tool', }, createTool: { title: 'Create Custom Tool', diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts index 1473fc23d38fca..a788ef0abee445 100644 --- a/web/i18n/zh-Hans/tools.ts +++ b/web/i18n/zh-Hans/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: '去工具列表管理', emptyTitle: '没有可用的工作流工具', emptyTip: '去 “工作流 -> 发布为工具” 添加', + emptyTitleCustom: '没有可用的自定义工具', + emptyTipCustom: '创建自定义工具', }, createTool: { title: '创建自定义工具', From 15485010505c46f2195a4971817cbeafaa85d024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=BB=E7=AC=91zz?= <43721571+shaxiaozz@users.noreply.github.com> Date: Thu, 19 Dec 2024 23:05:27 +0800 Subject: [PATCH 10/91] fix: comfyui tool supports https (#11823) --- api/core/tools/provider/builtin/comfyui/comfyui.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py index bab690af8292b7..a8127dd23f1553 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.py +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: ws = websocket.WebSocket() base_url = URL(credentials.get("base_url")) - ws_address = f"ws://{base_url.authority}/ws?clientId=test123" + ws_protocol = "ws" + if base_url.scheme == "https": + ws_protocol = "wss" + ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123" try: ws.connect(ws_address) From 44104797d662e64775405518d29bfb7807bdb349 Mon Sep 17 00:00:00 2001 From: stardust <642720202@qq.com> Date: Fri, 20 Dec 2024 02:21:41 +0800 Subject: [PATCH 11/91] fix: Enhance file type detection in HTTP Request node (#11797) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- Co-authored-by: 谭成 Co-authored-by: -LAN- --- api/core/workflow/nodes/http_request/node.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index d040cc9f559ded..a0fc8acaefaaec 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,4 +1,5 @@ import logging +import mimetypes from collections.abc import Mapping, Sequence from typing import Any @@ -156,20 +157,24 @@ def _extract_variable_selector_to_variable_mapping( def extract_files(self, url: str, response: Response) -> list[File]: """ - Extract files from response + Extract files from response by checking both Content-Type header and URL """ files = [] is_file = response.is_file content_type = response.content_type content = response.content - if is_file and content_type: + if is_file: + # Guess file extension from URL or Content-Type header + filename = url.split("?")[0].split("/")[-1] or "" + mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, 
conversation_id=None, file_binary=content, - mimetype=content_type, + mimetype=mime_type, ) mapping = { From 9d93ad1f16daa019890d278395a8c004f1ae07d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 20 Dec 2024 09:26:31 +0800 Subject: [PATCH 12/91] feat: add gemini-2.0-flash-thinking-exp-1219 (#11863) --- .../model_providers/google/llm/_position.yaml | 1 + .../gemini-2.0-flash-thinking-exp-1219.yaml | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml diff --git a/api/core/model_runtime/model_providers/google/llm/_position.yaml b/api/core/model_runtime/model_providers/google/llm/_position.yaml index 0b5e1025bfeb83..4ad0670e119af6 100644 --- a/api/core/model_runtime/model_providers/google/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/google/llm/_position.yaml @@ -1,4 +1,5 @@ - gemini-2.0-flash-exp +- gemini-2.0-flash-thinking-exp-1219 - gemini-1.5-pro - gemini-1.5-pro-latest - gemini-1.5-pro-001 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml new file mode 100644 index 00000000000000..dfcf8fd050ef06 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-thinking-exp-1219.yaml @@ -0,0 +1,39 @@ +model: gemini-2.0-flash-thinking-exp-1219 +label: + en_US: Gemini 2.0 Flash Thinking Exp 1219 +model_type: llm +features: + - agent-thought + - vision + - document + - video + - audio +model_properties: + mode: chat + context_size: 32767 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD From 95a7e5013707277c8d7440cf20f4b61e81b10a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=BB=E7=AC=91zz?= <43721571+shaxiaozz@users.noreply.github.com> Date: Fri, 20 Dec 2024 09:27:21 +0800 Subject: [PATCH 13/91] Fix comfyui tool https (#11859) --- .../tools/provider/builtin/comfyui/tools/comfyui_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index bed9cd1882fa29..f994cdbf66e78b 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -40,7 +40,10 @@ def queue_prompt(self, client_id: str, prompt: dict) -> str: def open_websocket_connection(self) -> tuple[WebSocket, str]: client_id = str(uuid.uuid4()) ws = WebSocket() - ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" + ws_protocol = "ws" + if self.base_url.scheme == "https": + ws_protocol = "wss" + ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}" ws.connect(ws_address) return ws, client_id From 463fbe268047520fba99b60c123a88a4c5141884 Mon Sep 17 00:00:00 2001 From: yihong Date: Fri, 20 Dec 2024 09:28:32 +0800 Subject: [PATCH 14/91] fix: better gard nan value from numpy for issue #11827 (#11864) Signed-off-by: yihong0618 --- .../azure_openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/cohere/text_embedding/text_embedding.py | 5 ++++- .../model_providers/openai/text_embedding/text_embedding.py | 5 ++++- .../model_providers/upstage/text_embedding/text_embedding.py | 5 ++++- api/core/rag/embedding/cached_embedding.py | 2 ++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index c45ce87ea76838..69d2cfaded453f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -92,7 +92,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 5fd4d637be7643..9e4df2706080f9 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -88,7 +88,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / 
np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index bec01fe6797f52..9c8c8d5882ee4e 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -97,7 +97,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding # calc usage usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 7dd495b55ef4e6..5b340e53bbc543 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -100,7 +100,10 @@ def _invoke( average = embeddings_batch[0] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embedding = (average / np.linalg.norm(average)).tolist() + if np.isnan(embedding).any(): + raise ValueError("Normalized embedding is nan please try again") + embeddings[i] = embedding usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 652f7e145fd94d..8ddda7e9832d97 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -116,6 +116,8 @@ def embed_query(self, text: str) -> list[float]: embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") From bb2f46d7cc9ecf92e297d5655c5514429b72a693 Mon Sep 17 00:00:00 2001 From: "Dr.MerdanBay" <110794035+KMerdan@users.noreply.github.com> Date: Fri, 20 Dec 2024 12:13:39 +0900 Subject: [PATCH 15/91] fix: add safe dictionary access for bedrock credentials (#11860) --- .../bedrock/get_bedrock_client.py | 16 ++++++++++++---- .../model_providers/bedrock/rerank/rerank.py | 5 ++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index a19ffbb20a6a9e..2ad37cef3b38f1 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -1,11 +1,19 @@ +from collections.abc import Mapping + import boto3 from 
botocore.config import Config +from core.model_runtime.errors.invoke import InvokeBadRequestError + + +def get_bedrock_client(service_name: str, credentials: Mapping[str, str]): + region_name = credentials.get("aws_region") + if not region_name: + raise InvokeBadRequestError("aws_region is required") + client_config = Config(region_name=region_name) + aws_access_key_id = credentials.get("aws_access_key_id") + aws_secret_access_key = credentials.get("aws_secret_access_key") -def get_bedrock_client(service_name, credentials=None): - client_config = Config(region_name=credentials["aws_region"]) - aws_access_key_id = credentials["aws_access_key_id"] - aws_secret_access_key = credentials["aws_secret_access_key"] if aws_access_key_id and aws_secret_access_key: # use aksk to call bedrock client = boto3.client( diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index e134db646f3d39..9da23ba1b0f08f 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -62,7 +62,10 @@ def _invoke( } ) modelId = model - region = credentials["aws_region"] + region = credentials.get("aws_region") + # region is a required field + if not region: + raise InvokeBadRequestError("aws_region is required in credentials") model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}" rerankingConfiguration = { "type": "BEDROCK_RERANKING_MODEL", From 2d186e1e76613442b9da67be2c747bab989ff67d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Dec 2024 13:46:54 +0800 Subject: [PATCH 16/91] chore(deps): bump nanoid from 3.3.7 to 3.3.8 in /web (#11876) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/yarn.lock b/web/yarn.lock index 6389f985c74445..035b9f3867d3a9 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -9882,9 +9882,9 @@ nan@^2.17.0: integrity sha512-nbajikzWTMwsW+eSsNm3QwlOs7het9gGJU5dDZzRTQGk03vyBOauxgI4VakDzE0PtsGTmXPsXTbbjVhRwR5mpw== nanoid@^3.3.6, nanoid@^3.3.7: - version "3.3.7" - resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz" - integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g== + version "3.3.8" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.8.tgz#b1be3030bee36aaff18bacb375e5cce521684baf" + integrity sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w== natural-compare-lite@^1.4.0: version "1.4.0" From 3599751f9379721dd9f87309838c2d43d56cd263 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 20 Dec 2024 14:12:29 +0800 Subject: [PATCH 17/91] chore(db): use a better way to export models and remove unused table (#11838) Signed-off-by: -LAN- --- api/controllers/console/app/wraps.py | 3 +- api/core/helper/encrypter.py | 2 +- api/extensions/ext_database.py | 15 +- api/extensions/ext_import_modules.py | 1 - api/libs/helper.py | 2 +- ...b07f66c737_remove_unused_tool_providers.py | 39 +++++ api/models/__init__.py | 142 +++++++++++++++++- api/models/account.py | 3 +- api/models/api_based_extension.py | 3 +- api/models/dataset.py | 2 +- api/models/engine.py | 13 ++ api/models/model.py | 2 +- api/models/provider.py | 3 +- api/models/source.py | 3 +- api/models/task.py | 2 +- 
api/models/tool.py | 47 ------ api/models/tools.py | 8 +- api/models/web.py | 3 +- api/models/workflow.py | 4 +- 19 files changed, 206 insertions(+), 91 deletions(-) create mode 100644 api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py create mode 100644 api/models/engine.py delete mode 100644 api/models/tool.py diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c71ee8e5dfea1d..63edb83079041e 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,8 +5,7 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models import App -from models.model import AppMode +from models import App, AppMode def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 96341a1b780a80..744fce1cf99cfe 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -1,6 +1,5 @@ import base64 -from extensions.ext_database import db from libs import rsa @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + from models.engine import db if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index e293afa1115e8b..93842a303683bb 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -1,18 +1,5 @@ -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import MetaData - from dify_app import DifyApp - -POSTGRES_INDEXES_NAMING_CONVENTION = { - "ix": "%(column_0_label)s_idx", - "uq": "%(table_name)s_%(column_0_name)s_key", - "ck": "%(table_name)s_%(constraint_name)s_check", - "fk": "%(table_name)s_%(column_0_name)s_fkey", - "pk": "%(table_name)s_pkey", -} - -metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) -db = SQLAlchemy(metadata=metadata) +from models import db def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index eefdfd38236662..9566f430b647fe 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -3,4 +3,3 @@ def init_app(app: DifyApp): from events import event_handlers # noqa: F401 - from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/libs/helper.py b/api/libs/helper.py index 026ded3506812d..91b1d1fe173d6f 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -13,7 +13,7 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields # type: ignore +from flask_restful import fields from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py new file mode 100644 index 00000000000000..881a9e3c1e06b6 --- /dev/null +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -0,0 +1,39 @@ +"""remove unused tool_providers + +Revision ID: 11b07f66c737 +Revises: cf8f4fc45278 +Create Date: 2024-12-19 
17:46:25.780116 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '11b07f66c737' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_providers') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 61a38870cf9b91..b0b9880ca42a9d 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,53 +1,187 @@ -from .account import Account, AccountIntegrate, InvitationCode, Tenant -from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .account import ( + Account, + AccountIntegrate, + AccountStatus, + InvitationCode, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetKeywordTable, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + Embedding, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, + TidbAuthBinding, + Whitelist, +) +from .engine import db +from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( + ApiRequest, ApiToken, App, + AppAnnotationHitHistory, + AppAnnotationSetting, AppMode, + AppModelConfig, Conversation, + DatasetRetrieverResource, + DifySetup, EndUser, + IconType, InstalledApp, Message, + MessageAgentThought, MessageAnnotation, + MessageChain, + MessageFeedback, MessageFile, + OperationLog, RecommendedApp, Site, + Tag, + TagBinding, + TraceAppConfig, UploadFile, ) -from .source import DataSourceOauthBinding -from .tools import ToolFile +from .provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderOrder, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) +from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .task import CeleryTask, CeleryTaskSet +from .tools import ( + ApiToolProvider, + BuiltinToolProvider, + PublishedAppTool, + ToolConversationVariables, + ToolFile, + ToolLabelBinding, + ToolModelInvoke, + WorkflowToolProvider, +) +from .web import PinnedConversation, SavedMessage from 
.workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, WorkflowRun, + WorkflowRunStatus, + WorkflowType, ) __all__ = [ + "APIBasedExtension", + "APIBasedExtensionPoint", "Account", "AccountIntegrate", + "AccountStatus", + "ApiRequest", "ApiToken", + "ApiToolProvider", # Added "App", + "AppAnnotationHitHistory", + "AppAnnotationSetting", + "AppDatasetJoin", "AppMode", + "AppModelConfig", + "BuiltinToolProvider", # Added + "CeleryTask", + "CeleryTaskSet", "Conversation", "ConversationVariable", + "CreatedByRole", + "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", + "DatasetCollectionBinding", + "DatasetKeywordTable", + "DatasetPermission", + "DatasetPermissionEnum", "DatasetProcessRule", + "DatasetQuery", + "DatasetRetrieverResource", + "DifySetup", "Document", "DocumentSegment", + "Embedding", "EndUser", + "ExternalKnowledgeApis", + "ExternalKnowledgeBindings", + "IconType", "InstalledApp", "InvitationCode", + "LoadBalancingModelConfig", "Message", + "MessageAgentThought", "MessageAnnotation", + "MessageChain", + "MessageFeedback", "MessageFile", + "OperationLog", + "PinnedConversation", + "Provider", + "ProviderModel", + "ProviderModelSetting", + "ProviderOrder", + "ProviderQuotaType", + "ProviderType", + "PublishedAppTool", "RecommendedApp", + "SavedMessage", "Site", + "Tag", + "TagBinding", "Tenant", + "TenantAccountJoin", + "TenantAccountJoinRole", + "TenantAccountRole", + "TenantDefaultModel", + "TenantPreferredModelProvider", + "TenantStatus", + "TidbAuthBinding", + "ToolConversationVariables", "ToolFile", + "ToolLabelBinding", + "ToolModelInvoke", + "TraceAppConfig", "UploadFile", + "UserFrom", + "Whitelist", "Workflow", "WorkflowAppLog", + "WorkflowAppLogCreatedFrom", + "WorkflowNodeExecution", + "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionTriggeredFrom", "WorkflowRun", + "WorkflowRunStatus", + "WorkflowRunTriggeredFrom", + "WorkflowToolProvider", + "WorkflowType", + "db", ] diff --git a/api/models/account.py b/api/models/account.py index 951e836dec1873..ce17b90def3fa8 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,8 +3,7 @@ from flask_login import UserMixin -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 97173747afc4b1..4d4182cabdc163 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,6 @@ import enum -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/dataset.py b/api/models/dataset.py index 8ab957e875a1bf..97e4d6c0ef8981 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -15,10 +15,10 @@ from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from extensions.ext_storage import storage from .account import Account +from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID diff --git a/api/models/engine.py b/api/models/engine.py new file mode 100644 index 00000000000000..dda93bc9415cfc --- /dev/null +++ b/api/models/engine.py @@ -0,0 +1,13 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData + +POSTGRES_INDEXES_NAMING_CONVENTION = { + "ix": "%(column_0_label)s_idx", + "uq": 
"%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + +metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) +db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 03b8e0bea553aa..54b719628c200b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -16,11 +16,11 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from extensions.ext_database import db from libs.helper import generate_string from models.enums import CreatedByRole from .account import Account, Tenant +from .engine import db from .types import StringUUID diff --git a/api/models/provider.py b/api/models/provider.py index 644915e781084b..65f70b76e9c362 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,6 @@ from enum import Enum -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/source.py b/api/models/source.py index 07695f06e6cf00..4d98572ef81fcb 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -2,8 +2,7 @@ from sqlalchemy.dialects.postgresql import JSONB -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/task.py b/api/models/task.py index 5d89ff85acc781..27571e24746fe7 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,7 +2,7 @@ from celery import states -from extensions.ext_database import db +from .engine import db class CeleryTask(db.Model): diff --git a/api/models/tool.py b/api/models/tool.py deleted file mode 100644 index a81bb65174a724..00000000000000 --- a/api/models/tool.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from enum import Enum - -from extensions.ext_database import db - -from .types import StringUUID - - -class ToolProviderName(Enum): - SERPAPI = "serpapi" - - @staticmethod - def value_of(value): - for member in ToolProviderName: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class ToolProvider(db.Model): - __tablename__ = "tool_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), - db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), - ) - - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - tool_name = db.Column(db.String(40), nullable=False) - encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - - @property - def credentials_is_set(self): - """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ - return self.encrypted_credentials is not None - - @property - def credentials(self): - """ - Returns the decrypted config. 
- """ - return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None diff --git a/api/models/tools.py b/api/models/tools.py index 4040339e026474..c390be4625aab2 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from extensions.ext_database import db +from .engine import db from .model import Account, App, Tenant from .types import StringUUID @@ -82,7 +82,7 @@ def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) @property - def app(self) -> App: + def app(self): return db.session.query(App).filter(App.id == self.app_id).first() @@ -201,10 +201,6 @@ class WorkflowToolProvider(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - @property - def schema_type(self) -> ApiProviderSchemaType: - return ApiProviderSchemaType.value_of(self.schema_type_str) - @property def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() diff --git a/api/models/web.py b/api/models/web.py index bc088c185d5a8b..a0f87cf4566f04 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,5 +1,4 @@ -from extensions.ext_database import db - +from .engine import db from .model import Message from .types import StringUUID diff --git a/api/models/workflow.py b/api/models/workflow.py index 09e3728d7c2445..4ddf2e9082f10c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,12 +12,12 @@ from constants import HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Variable -from extensions.ext_database import db from factories import variable_factory from libs import helper from models.enums import CreatedByRole from .account import Account +from .engine import db from .types import StringUUID @@ -399,7 +399,7 @@ class WorkflowRun(db.Model): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[str] = mapped_column(sa.Text, default="{}") + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) From 996a9135f6b3244a68e4ec2b3da67c24f9a928bd Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 20 Dec 2024 14:12:50 +0800 Subject: [PATCH 18/91] feat(llm_node): support order in text and files (#11837) Signed-off-by: -LAN- --- api/core/file/file_manager.py | 28 ++-- api/core/file/file_repository.py | 32 ---- api/core/file/models.py | 32 ++++ api/core/model_runtime/entities/__init__.py | 2 + .../entities/message_entities.py | 8 +- api/core/workflow/nodes/llm/entities.py | 1 + api/core/workflow/nodes/llm/node.py | 145 +++++++++--------- .../question_classifier_node.py | 4 +- api/factories/file_factory.py | 3 + api/models/model.py | 37 ++++- .../prompt/test_advanced_prompt_transform.py | 1 + api/tests/unit_tests/core/test_file.py | 30 +--- .../http_request/test_http_request_node.py | 2 + .../core/workflow/nodes/llm/test_node.py | 55 
+++++-- .../core/workflow/nodes/llm/test_scenarios.py | 4 +- .../core/workflow/nodes/test_if_else.py | 1 + .../core/workflow/nodes/test_list_operator.py | 6 + .../core/workflow/test_variable_pool.py | 1 + 18 files changed, 217 insertions(+), 175 deletions(-) delete mode 100644 api/core/file/file_repository.py diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index dd363ee1baeb8b..15eb351a7ef309 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,15 +1,14 @@ import base64 from configs import dify_config -from core.file import file_repository from core.helper import ssrf_proxy from core.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + MultiModalPromptMessageContent, VideoPromptMessageContent, ) -from extensions.ext_database import db from extensions.ext_storage import storage from . import helpers @@ -41,7 +40,7 @@ def to_prompt_message_content( /, *, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -): +) -> MultiModalPromptMessageContent: if f.extension is None: raise ValueError("Missing file extension") if f.mime_type is None: @@ -70,16 +69,13 @@ def to_prompt_message_content( def download(f: File, /): - if f.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - return _download_file_content(tool_file.file_key) - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - return _download_file_content(upload_file.key) - # remote file - response = ssrf_proxy.get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content + if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): + return _download_file_content(f._storage_key) + elif f.transfer_method == FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + return response.content + raise ValueError(f"unsupported transfer method: {f.transfer_method}") def _download_file_content(path: str, /): @@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - data = _download_file_content(upload_file.key) + data = _download_file_content(f._storage_key) case FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - data = _download_file_content(tool_file.file_key) + data = _download_file_content(f._storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py deleted file mode 100644 index 975e1e72db0e0a..00000000000000 --- a/api/core/file/file_repository.py +++ /dev/null @@ -1,32 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from models import ToolFile, UploadFile - -from .models import File - - -def get_upload_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(UploadFile).filter( - UploadFile.id == file.related_id, - UploadFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"upload file {file.related_id} not found") - return record - - -def 
get_tool_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(ToolFile).filter( - ToolFile.id == file.related_id, - ToolFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"tool file {file.related_id} not found") - return record diff --git a/api/core/file/models.py b/api/core/file/models.py index 3e7e189c62cade..4b4674da095d34 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -47,6 +47,38 @@ class File(BaseModel): mime_type: Optional[str] = None size: int = -1 + # These properties are private and should not be exposed to the outside. + _storage_key: str + + def __init__( + self, + *, + id: Optional[str] = None, + tenant_id: str, + type: FileType, + transfer_method: FileTransferMethod, + remote_url: Optional[str] = None, + related_id: Optional[str] = None, + filename: Optional[str] = None, + extension: Optional[str] = None, + mime_type: Optional[str] = None, + size: int = -1, + storage_key: str, + ): + super().__init__( + id=id, + tenant_id=tenant_id, + type=type, + transfer_method=transfer_method, + remote_url=remote_url, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + self._storage_key = storage_key + def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index 1c73755cffd62e..c3e1351e3b6d26 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -4,6 +4,7 @@ AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + MultiModalPromptMessageContent, PromptMessage, PromptMessageContent, PromptMessageContentType, @@ -27,6 +28,7 @@ "LLMResultChunkDelta", "LLMUsage", "ModelPropertyKey", + "MultiModalPromptMessageContent", "PromptMessage", "PromptMessage", "PromptMessageContent", diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 7928426106ccd5..0efe46f87d6de9 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent): """ type: PromptMessageContentType - format: str = Field(..., description="the format of multi-modal file") - base64_data: str = Field("", description="the base64 data of multi-modal file") - url: str = Field("", description="the url of multi-modal file") - mime_type: str = Field(..., description="the mime type of multi-modal file") + format: str = Field(default=..., description="the format of multi-modal file") + base64_data: str = Field(default="", description="the base64 data of multi-modal file") + url: str = Field(default="", description="the url of multi-modal file") + mime_type: str = Field(default=..., description="the mime type of multi-modal file") @computed_field(return_type=str) @property diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 19a66087f7d175..505068104c2c2d 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -50,6 +50,7 @@ def convert_none_jinja2_variables(cls, v: Any): class LLMNodeChatModelMessage(ChatModelMessage): + text: str = "" jinja2_text: Optional[str] = None diff --git 
a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 67e62cb8750430..bcba3f5a4d55f1 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -145,8 +145,8 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( - user_query=query, - user_files=files, + sys_query=query, + sys_files=files, context=context, memory=memory, model_config=model_config, @@ -545,8 +545,8 @@ def _fetch_memory( def _fetch_prompt_messages( self, *, - user_query: str | None = None, - user_files: Sequence["File"], + sys_query: str | None = None, + sys_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -562,7 +562,7 @@ def _fetch_prompt_messages( if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - _handle_list_messages( + self._handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -581,14 +581,14 @@ def _fetch_prompt_messages( prompt_messages.extend(memory_messages) # Add current query to the prompt messages - if user_query: + if sys_query: message = LLMNodeChatModelMessage( - text=user_query, + text=sys_query, role=PromptMessageRole.USER, edition_type="basic", ) prompt_messages.extend( - _handle_list_messages( + self._handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -635,24 +635,27 @@ def _fetch_prompt_messages( raise ValueError("Invalid prompt content type") # Add current query to the prompt message - if user_query: + if sys_query: if prompt_content_type == str: - prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content elif prompt_content_type == list: for content_item in prompt_content: if content_item.type == PromptMessageContentType.TEXT: - content_item.data = user_query + "\n" + content_item.data + content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") else: raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - if vision_enabled and user_files: + # The sys_files will be deprecated later + if vision_enabled and sys_files: file_prompts = [] - for file in user_files: + for file in sys_files: file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt if ( len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage) @@ -662,7 +665,7 @@ def _fetch_prompt_messages( else: prompt_messages.append(UserPromptMessage(content=file_prompts)) - # Filter prompt messages + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message.content, list): @@ -846,6 +849,58 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: }, } + def _handle_list_messages( + self, + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, + ) -> Sequence[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + 
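# Each message becomes one prompt message: jinja2 templates are rendered to plain text, while basic templates get {#context#} substituted and any file segments converted into prompt message contents. +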
for message in messages: + contents: list[PromptMessageContent] = [] + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + contents.append(TextPromptMessageContent(data=result_text)) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + contents.append(file_content) + else: + plain_text = segment.markdown.strip() + if plain_text: + contents.append(TextPromptMessageContent(data=plain_text)) + prompt_message = _combine_message_content_with_role(contents=contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -880,68 +935,6 @@ def _render_jinja2_message( return result_text -def _handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, -) -> Sequence[PromptMessage]: - prompt_messages = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinjia2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - if isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - 
- if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages - - def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity ) -> int: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 7594036b50abb4..15cfabe4780ee2 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,10 +86,10 @@ def _run(self): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - user_query=query, + sys_query=query, memory=memory, model_config=model_config, - user_files=files, + sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 8538775a67242b..01d95dcfb3d2fe 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -139,6 +139,7 @@ def _build_from_local_file( remote_url=row.source_url, related_id=mapping.get("upload_file_id"), size=row.size, + storage_key=row.key, ) @@ -168,6 +169,7 @@ def _build_from_remote_url( mime_type=mime_type, extension=extension, size=file_size, + storage_key="", ) @@ -220,6 +222,7 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, + storage_key=tool_file.file_key, ) diff --git a/api/models/model.py b/api/models/model.py index 54b719628c200b..33f36af1e49445 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -560,13 +560,29 @@ class Conversation(db.Model): @property def inputs(self): inputs = self._inputs.copy() + + # Convert file mapping to File object for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + return inputs @inputs.setter @@ -758,12 +774,25 @@ class Message(db.Model): def inputs(self): inputs = self._inputs.copy() for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. 
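+ # Rebuild stored file mappings into File objects via the file factory; related_id is remapped to the id key each transfer method expects.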
+ from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) return inputs @inputs.setter diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 941984863bba00..ee0f7672f8c814 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="", ) ] diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 4edbc01cc778e8..e02d882780900f 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,34 +1,9 @@ import json -from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow -def test_file_loads_and_dumps(): - file = File( - id="file1", - tenant_id="tenant1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image1.jpg", - ) - - file_dict = file.model_dump() - assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY - assert file_dict["type"] == file.type.value - assert isinstance(file_dict["type"], str) - assert file_dict["transfer_method"] == file.transfer_method.value - assert isinstance(file_dict["transfer_method"], str) - assert "_extra_config" not in file_dict - - file_obj = File.model_validate(file_dict) - assert file_obj.id == file.id - assert file_obj.tenant_id == file.tenant_id - assert file_obj.type == file.type - assert file_obj.transfer_method == file.transfer_method - assert file_obj.remote_url == file.remote_url - - def test_file_to_dict(): file = File( id="file1", @@ -36,10 +11,11 @@ def test_file_to_dict(): type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="storage_key", ) file_dict = file.to_dict() - assert "_extra_config" not in file_dict + assert "_storage_key" not in file_dict assert "url" in file_dict diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 70ec023140e152..97bacada74572d 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) @@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5f69a4f2056da3..76db42ef106dfa 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -21,7 +21,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute @@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node): filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ) llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) @@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ), File( id="2", @@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", + storage_key="", ), ] llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) @@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + storage_key="", ) ] fake_query = faker.sentence() prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=files, + sys_query=fake_query, + sys_files=files, context=None, memory=None, model_config=model_config, @@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): test_scenarios = [ LLMNodeTestScenario( description="No files", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], features=[], vision_enabled=False, vision_detail=None, @@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), LLMNodeTestScenario( description="User files", - user_query=fake_query, - user_files=[ + sys_query=fake_query, + sys_files=[ File( tenant_id="test", type=FileType.IMAGE, @@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): remote_url=fake_remote_url, extension=".jpg", mime_type="image/jpg", + storage_key="", ) ], vision_enabled=True, @@ -370,8 +376,8 @@ def 
test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), LLMNodeTestScenario( description="Prompt template with variable selector of File", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], vision_enabled=False, vision_detail=fake_vision_detail, features=[ModelFeature.VISION], @@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): remote_url=fake_remote_url, extension=".jpg", mime_type="image/jpg", + storage_key="", ) }, ), @@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=scenario.user_query, - user_files=scenario.user_files, + sys_query=scenario.sys_query, + sys_files=scenario.sys_files, context=fake_context, memory=memory, model_config=model_config, @@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): assert ( prompt_messages == scenario.expected_messages ), f"Message content mismatch in scenario: {scenario.description}" + + +def test_handle_list_messages_basic(llm_node): + messages = [ + LLMNodeChatModelMessage( + text="Hello, {#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ] + context = "world" + jinja2_variables = [] + variable_pool = llm_node.graph_runtime_state.variable_pool + vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH + + result = llm_node._handle_list_messages( + messages=messages, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail_config, + ) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == [TextPromptMessageContent(data="Hello, world")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index 8e39445baf5490..21bb857353262c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel): """Test scenario for LLM node testing.""" description: str = Field(..., description="Description of the test scenario") - user_query: str = Field(..., description="User query input") - user_files: Sequence[File] = Field(default_factory=list, description="List of user files") + sys_query: str = Field(..., description="User query input") + sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") vision_enabled: bool = Field(default=False, description="Whether vision is enabled") vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d964d0e3529c69..41e2c5d48468f6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -248,6 +248,7 @@ def test_array_file_contains_file_name(): transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", + storage_key="", ), ], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d20dfc5b311698..36116d35404cf5 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", + storage_key="", ), File( filename="document1.pdf", @@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", + storage_key="", ), File( filename="image2.png", @@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", + storage_key="", ), File( filename="audio1.mp3", @@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", + storage_key="", ), ] variable = ArrayFileSegment(value=files) @@ -130,6 +134,7 @@ def test_get_file_extract_string_func(): mime_type="text/plain", remote_url="https://example.com/test_file.txt", related_id="test_related_id", + storage_key="", ) # Test each case @@ -150,6 +155,7 @@ def test_get_file_extract_string_func(): mime_type=None, remote_url=None, related_id="test_related_id", + storage_key="", ) assert _get_file_extract_string_func(key="name")(empty_file) == "" diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9ea6acac17132d..efbcdc760c6995 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -19,6 +19,7 @@ def file(): related_id="test_related_id", remote_url="test_url", filename="test_file.txt", + storage_key="", ) From f6247fe67cda7d032d20406a5e04fe0c1c1aa9a4 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Fri, 20 Dec 2024 14:13:44 +0800 Subject: [PATCH 19/91] Feat: Add partial success status to the app log (#11869) Co-authored-by: Novice Lee --- api/fields/conversation_fields.py | 3 +- api/models/model.py | 24 +++++++++ web/app/components/app/log/list.tsx | 51 +++++++++++++++++--- web/app/components/app/workflow-log/list.tsx | 8 +++ 4 files changed, 79 insertions(+), 7 deletions(-) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 5bd21be80779a4..6a9e347b1e04b4 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -85,7 +85,7 @@ def format(self, value): } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} - +status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer} model_config_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -166,6 +166,7 @@ def format(self, value): "message_count": fields.Integer, "user_feedback_stats": fields.Nested(feedback_stat_fields), "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "status_count": fields.Nested(status_count_fields), } conversation_with_summary_pagination_fields = { diff --git a/api/models/model.py b/api/models/model.py index 33f36af1e49445..8608b12af1fc8e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -18,6 +18,7 @@ from core.file.tool_file_parser import ToolFileParser from libs.helper import generate_string from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus from .account import Account, Tenant from 
.engine import db @@ -695,6 +696,29 @@ def admin_feedback_stats(self): return {"like": like, "dislike": dislike} + @property + def status_count(self): + messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() + status_counts = { + WorkflowRunStatus.SUCCEEDED: 0, + WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.PARTIAL_SUCCESSED: 0, + } + + for message in messages: + if message.workflow_run: + status_counts[message.workflow_run.status] += 1 + + return ( + { + "success": status_counts[WorkflowRunStatus.SUCCEEDED], + "failed": status_counts[WorkflowRunStatus.FAILED], + "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCESSED], + } + if messages + else None + ) + @property def first_message(self): return db.session.query(Message).filter(Message.conversation_id == self.id).first() diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 8f064c209efe10..383aeb1492c769 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -16,6 +16,7 @@ import { createContext, useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useTranslation } from 'react-i18next' import type { ChatItemInTree } from '../../base/chat/types' +import Indicator from '../../header/indicator' import VarPanel from './var-panel' import type { FeedbackFunc, FeedbackType, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' import type { Annotation, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log' @@ -57,6 +58,12 @@ type IDrawerContext = { appDetail?: App } +type StatusCount = { + success: number + failed: number + partial_success: number +} + const DrawerContext = createContext({} as IDrawerContext) /** @@ -71,6 +78,33 @@ const HandThumbIconWithCount: FC<{ count: number; iconType: 'up' | 'down' }> = ( } +const statusTdRender = (statusCount: StatusCount) => { + if (statusCount.partial_success + statusCount.failed === 0) { + return ( +
<div> + <Indicator color='green' /> + Success + </div> + ) + } + else if (statusCount.failed === 0) { + return ( + <div> + <Indicator color='green' /> + Partial Success + </div> + ) + } + else { + return ( + <div> + <Indicator color='red' /> + {statusCount.failed} {`${statusCount.failed > 1 ? 'Failures' : 'Failure'}`} + </div>
+ ) + } +} + const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { const newChatList: IChatItem[] = [] messages.forEach((item: ChatMessage) => { @@ -496,8 +530,8 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } /** - * Text App Conversation Detail Component - */ + * Text App Conversation Detail Component + */ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { // Text Generator App Session Details Including Message List const detailParams = ({ url: `/apps/${appId}/completion-conversations/${conversationId}` }) @@ -542,8 +576,8 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st } /** - * Chat App Conversation Detail Component - */ + * Chat App Conversation Detail Component + */ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { const detailParams = { url: `/apps/${appId}/chat-conversations/${conversationId}` } const { data: conversationDetail } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchChatConversationDetail) @@ -585,8 +619,8 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } } /** - * Conversation list component including basic information - */ + * Conversation list component including basic information + */ const ConversationList: FC = ({ logs, appDetail, onRefresh }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -597,6 +631,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const [showDrawer, setShowDrawer] = useState(false) // Whether to display the chat details drawer const [currentConversation, setCurrentConversation] = useState() // Currently selected conversation const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app + const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app const { setShowPromptLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ setShowPromptLogModal: state.setShowPromptLogModal, setShowAgentLogModal: state.setShowAgentLogModal, @@ -639,6 +674,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) {isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')} {t('appLog.table.header.endUser')} + {isChatflow && {t('appLog.table.header.status')}} {isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')} {t('appLog.table.header.userRate')} {t('appLog.table.header.adminRate')} @@ -669,6 +705,9 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} {renderTdValue(endUser || defaultValue, !endUser)} + {isChatflow && + {statusTdRender(log.status_count)} + } {renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index e3de4a957f22e4..41db9b5d466822 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -63,6 +63,14 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => { ) } + if (status === 'partial-succeeded') { + return ( +
<div> + <Indicator color='green' /> + Partial Success + </div>
+ ) + } } const onCloseDrawer = () => { From 7abc7fa5735fb9f168ba47e4a79a60373072c6af Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Fri, 20 Dec 2024 14:14:06 +0800 Subject: [PATCH 20/91] Feat: Retry on node execution errors (#11871) Co-authored-by: Novice Lee --- .../advanced_chat/generate_task_pipeline.py | 17 + .../apps/workflow/generate_task_pipeline.py | 19 +- api/core/app/apps/workflow_app_runner.py | 32 ++ api/core/app/entities/queue_entities.py | 32 ++ api/core/app/entities/task_entities.py | 70 ++++ .../task_pipeline/workflow_cycle_manage.py | 93 ++++++ api/core/helper/ssrf_proxy.py | 1 - api/core/workflow/entities/node_entities.py | 3 + .../workflow/graph_engine/entities/event.py | 7 + .../workflow/graph_engine/graph_engine.py | 310 ++++++++++-------- api/core/workflow/nodes/base/entities.py | 13 + api/core/workflow/nodes/base/node.py | 11 +- api/core/workflow/nodes/enums.py | 1 + api/core/workflow/nodes/event/__init__.py | 9 +- api/core/workflow/nodes/event/event.py | 25 ++ .../workflow/nodes/http_request/executor.py | 4 + api/core/workflow/nodes/http_request/node.py | 8 +- api/fields/workflow_run_fields.py | 14 + ...dd_retry_index_field_to_node_execution_.py | 33 ++ api/models/workflow.py | 2 + api/services/workflow_service.py | 141 +++++--- .../workflow/nodes/test_continue_on_error.py | 12 +- .../core/workflow/nodes/test_retry.py | 73 +++++ 23 files changed, 739 insertions(+), 191 deletions(-) create mode 100644 api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_retry.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 32a23a7fdb8690..ce0e95962772ad 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -22,6 +22,7 @@ QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -328,6 +329,22 @@ def _process_stream_response( workflow_node_execution=workflow_node_execution, ) + if response: + yield response + elif isinstance( + event, + QueueNodeRetryEvent, + ): + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2330229f439b97..79e5e2bcb96d60 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -18,6 +18,7 @@ QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -286,9 +287,25 @@ def _process_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if node_failed_response: yield node_failed_response + elif isinstance( + event, + QueueNodeRetryEvent, + ): + workflow_node_execution = self._handle_workflow_node_execution_retried( + 
workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: raise Exception("Workflow run not initialized.") diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 97c2cc5bb9bfd9..2fbf711175aab5 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -11,6 +11,7 @@ QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -38,6 +39,7 @@ NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -420,6 +422,36 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) + elif isinstance(event, NodeRunRetryEvent): + self._publish_event( + QueueNodeRetryEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.error, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + start_index=event.start_index, + ) + ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5b2036c7f9ba6a..49b7e80246a3f2 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -43,6 +43,7 @@ class QueueEvent(StrEnum): ERROR = "error" PING = "ping" STOP = "stop" + RETRY = "retry" class AppQueueEvent(BaseModel): @@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): iteration_duration_map: Optional[dict[str, float]] = None +class QueueNodeRetryEvent(AppQueueEvent): + """QueueNodeRetryEvent entity""" + + event: QueueEvent = QueueEvent.RETRY + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = 
None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + retry_index: int  # retry index + start_index: int  # start index + + class QueueNodeInIterationFailedEvent(AppQueueEvent): """ QueueNodeInIterationFailedEvent entity diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7fe06b3af8bbb4..dd088a897816ef 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -52,6 +52,7 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + NODE_RETRY = "node_retry" PARALLEL_BRANCH_STARTED = "parallel_branch_started" PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" @@ -342,6 +343,75 @@ def to_ignore_detail_dict(self): } +class NodeRetryStreamResponse(StreamResponse): + """ + NodeRetryStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + node_id: str + node_type: str + title: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + status: str + error: Optional[str] = None + elapsed_time: float + execution_metadata: Optional[dict] = None + created_at: int + finished_at: int + files: Optional[Sequence[Mapping[str, Any]]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + retry_index: int = 0 + + event: StreamEvent = StreamEvent.NODE_RETRY + workflow_run_id: str + data: Data + + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "process_data": None, + "outputs": None, + "status": self.data.status, + "error": None, + "elapsed_time": self.data.elapsed_time, + "execution_metadata": None, + "created_at": self.data.created_at, + "finished_at": self.data.finished_at, + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + "retry_index": self.data.retry_index, + }, + } + + class ParallelBranchStartStreamResponse(StreamResponse): """ ParallelBranchStartStreamResponse entity diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 8b588b48ebe8a5..e2fa12b1cddd70 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -15,6 +15,7 @@ QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -26,6 +27,7 @@ IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, NodeFinishStreamResponse, + NodeRetryStreamResponse, NodeStartStreamResponse, 
ParallelBranchFinishedStreamResponse, ParallelBranchStartStreamResponse, @@ -423,6 +425,52 @@ def _handle_workflow_node_execution_failed( return workflow_node_execution + def _handle_workflow_node_execution_retried( + self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + ) -> WorkflowNodeExecution: + """ + Workflow node execution retried + :param event: queue node retry event + :return: + """ + created_at = event.start_at + finished_at = datetime.now(UTC).replace(tzinfo=None) + elapsed_time = (finished_at - created_at).total_seconds() + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = created_at + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.error = event.error + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.index = event.start_index + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + + return workflow_node_execution + ################################################# # to stream responses # ################################################# @@ -587,6 +635,51 @@ def _workflow_node_finish_to_stream_response( ), ) + def _workflow_node_retry_to_stream_response( + self, + event: QueueNodeRetryEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeRetryStreamResponse]: + """ + Workflow node retry to stream response. 
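+ Emitted once for each retry attempt of a node.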
+ :param event: queue node retry event + :param task_id: task id + :param workflow_node_execution: workflow node execution + :return: + """ + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + + return NodeRetryStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeRetryStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + ), + ) + def _workflow_parallel_branch_start_to_stream_response( self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ef4516b404af13..d8aa80536412e6 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) retries = 0 - stream = kwargs.pop("stream", False) while retries <= max_retries: try: if dify_config.SSRF_PROXY_ALL_URL: diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 976a5ef74e320d..ca01dcd7d8d4a8 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -45,3 +45,6 @@ class NodeRunResult(BaseModel): error: Optional[str] = None # error message if status is failed error_type: Optional[str] = None # error type if status is failed + + # single step node run retry + retry_index: int = 0 diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 73450349ded634..999715316494fb 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunRetryEvent(BaseNodeEvent): + error: str = Field(..., description="error") + retry_index: int = Field(..., description="which retry attempt is about to be performed") + start_at: datetime = Field(..., description="retry start time") + start_index: int = Field(..., description="retry start index") + + ########################################### # Parallel Branch Events ########################################### diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 
034b4bd3992dcb..e292b099684b7b 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -5,6 +5,7 @@
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast
 
 from flask import Flask, current_app
@@ -25,6 +26,7 @@
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ def _run_parallel_node(
 
     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ def _run_node(
             )
             db.session.close()
 
+        max_retries = node_instance.node_data.retry_config.max_retries
+        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        retries = 0
+        should_continue_retry = True
+        while should_continue_retry and retries <= max_retries:
+            try:
+                # run node
+                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+                generator = node_instance.run()
+                for item in generator:
+                    if isinstance(item, GraphEngineEvent):
+                        if isinstance(item, BaseIterationEvent):
+                            # add parallel info to iteration event
+                            item.parallel_id = parallel_id
+                            item.parallel_start_node_id = parallel_start_node_id
+                            item.parent_parallel_id = parent_parallel_id
+                            item.parent_parallel_start_node_id = parent_parallel_start_node_id
+
+                        yield item
+                    else:
+                        if isinstance(item, RunCompletedEvent):
+                            run_result = item.run_result
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if (
+                                    retries == max_retries
+                                    and node_instance.node_type == NodeType.HTTP_REQUEST
+                                    and run_result.outputs
+                                    and not node_instance.should_continue_on_error
+                                ):
+                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                                if node_instance.should_retry and retries < max_retries:
+                                    retries += 1
+                                    self.graph_runtime_state.node_run_steps += 1
+                                    route_node_state.node_run_result = run_result
+                                    yield NodeRunRetryEvent(
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        error=run_result.error,
+                                        retry_index=retries,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                        start_at=retry_start_at,
+                                        start_index=self.graph_runtime_state.node_run_steps,
+                                    )
+                                    time.sleep(retry_interval)
+                                    continue
+                            route_node_state.set_finished(run_result=run_result)
+
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if node_instance.should_continue_on_error:
+                                    # if run failed, handle error
+                                    run_result = self._handle_continue_on_error(
+                                        node_instance,
+                                        item.run_result,
+                                        self.graph_runtime_state.variable_pool,
+                                        handle_exceptions=handle_exceptions,
+                                    )
+                                    route_node_state.node_run_result = run_result
+                                    route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                    if run_result.outputs:
+                                        for variable_key, variable_value in run_result.outputs.items():
+                                            # append variables to variable pool recursively
+                                            self._append_variables_recursively(
+                                                node_id=node_instance.node_id,
+                                                variable_key_list=[variable_key],
+                                                variable_value=variable_value,
+                                            )
+                                    yield NodeRunExceptionEvent(
+                                        error=run_result.error or "System Error",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                                else:
+                                    yield NodeRunFailedEvent(
+                                        error=route_node_state.failed_reason or "Unknown error.",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                                if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
+                                    node_instance.node_id
+                                ):
+                                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
+                                if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+                                    # plus state total_tokens
+                                    self.graph_runtime_state.total_tokens += int(
+                                        run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
+                                    )
-        try:
-            # run node
-            generator = node_instance.run()
-            for item in generator:
-                if isinstance(item, GraphEngineEvent):
-                    if isinstance(item, BaseIterationEvent):
-                        # add parallel info to iteration event
-                        item.parallel_id = parallel_id
-                        item.parallel_start_node_id = parallel_start_node_id
-                        item.parent_parallel_id = parent_parallel_id
-                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+                                if run_result.llm_usage:
+                                    # use the latest usage
+                                    self.graph_runtime_state.llm_usage += run_result.llm_usage
-                    yield item
-                else:
-                    if isinstance(item, RunCompletedEvent):
-                        run_result = item.run_result
-                        route_node_state.set_finished(run_result=run_result)
-
-                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
-                            if node_instance.should_continue_on_error:
-                                # if run failed, handle error
-                                run_result = self._handle_continue_on_error(
-                                    node_instance,
-                                    item.run_result,
-                                    self.graph_runtime_state.variable_pool,
-                                    handle_exceptions=handle_exceptions,
-                                )
-                                route_node_state.node_run_result = run_result
-                                route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                # append node output variables to variable pool
                                 if run_result.outputs:
                                     for variable_key, variable_value in run_result.outputs.items():
                                         # append variables to variable pool recursively
@@ -645,21 +732,23 @@ def _run_node(
                                             variable_key_list=[variable_key],
                                             variable_value=variable_value,
                                         )
-                                yield NodeRunExceptionEvent(
-                                    error=run_result.error or "System Error",
-                                    id=node_instance.id,
-                                    node_id=node_instance.node_id,
-                                    node_type=node_instance.node_type,
-                                    node_data=node_instance.node_data,
-                                    route_node_state=route_node_state,
-                                    parallel_id=parallel_id,
-                                    parallel_start_node_id=parallel_start_node_id,
-                                    parent_parallel_id=parent_parallel_id,
-                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                )
-                            else:
-                                yield NodeRunFailedEvent(
-                                    error=route_node_state.failed_reason or "Unknown error.",
+
+                                # add parallel info to run result metadata
+                                if parallel_id and parallel_start_node_id:
+                                    if not run_result.metadata:
+                                        run_result.metadata = {}
+
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
+                                        parallel_start_node_id
+                                    )
+                                    if parent_parallel_id and parent_parallel_start_node_id:
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+                                            parent_parallel_start_node_id
+                                        )
+
+                                yield NodeRunSucceededEvent(
                                     id=node_instance.id,
                                     node_id=node_instance.node_id,
                                     node_type=node_instance.node_type,
                                     node_data=node_instance.node_data,
@@ -670,108 +759,59 @@ def _run_node(
                                     parent_parallel_id=parent_parallel_id,
                                     parent_parallel_start_node_id=parent_parallel_start_node_id,
                                 )
+                                should_continue_retry = False
-                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
-                            if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
-                                node_instance.node_id
-                            ):
-                                run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
-                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                                # plus state total_tokens
-                                self.graph_runtime_state.total_tokens += int(
-                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
-                                )
-
-                            if run_result.llm_usage:
-                                # use the latest usage
-                                self.graph_runtime_state.llm_usage += run_result.llm_usage
-
-                            # append node output variables to variable pool
-                            if run_result.outputs:
-                                for variable_key, variable_value in run_result.outputs.items():
-                                    # append variables to variable pool recursively
-                                    self._append_variables_recursively(
-                                        node_id=node_instance.node_id,
-                                        variable_key_list=[variable_key],
-                                        variable_value=variable_value,
-                                    )
-
-                            # add parallel info to run result metadata
-                            if parallel_id and parallel_start_node_id:
-                                if not run_result.metadata:
-                                    run_result.metadata = {}
-
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
-                                if parent_parallel_id and parent_parallel_start_node_id:
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
-                                        parent_parallel_start_node_id
-                                    )
-
-                            yield NodeRunSucceededEvent(
+                                break
+                        elif isinstance(item, RunStreamChunkEvent):
+                            yield NodeRunStreamChunkEvent(
                                 id=node_instance.id,
                                 node_id=node_instance.node_id,
                                 node_type=node_instance.node_type,
                                 node_data=node_instance.node_data,
+                                chunk_content=item.chunk_content,
+                                from_variable_selector=item.from_variable_selector,
                                 route_node_state=route_node_state,
                                 parallel_id=parallel_id,
                                 parallel_start_node_id=parallel_start_node_id,
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
                             )
-
-                        break
-                    elif isinstance(item, RunStreamChunkEvent):
-                        yield NodeRunStreamChunkEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            chunk_content=item.chunk_content,
-                            from_variable_selector=item.from_variable_selector,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-                    elif isinstance(item, RunRetrieverResourceEvent):
-                        yield NodeRunRetrieverResourceEvent(
-                            id=node_instance.id,
-                            node_id=node_instance.node_id,
-                            node_type=node_instance.node_type,
-                            node_data=node_instance.node_data,
-                            retriever_resources=item.retriever_resources,
-                            context=item.context,
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
-                            parallel_start_node_id=parallel_start_node_id,
-                            parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id,
-                        )
-        except GenerateTaskStoppedError:
-            # trigger node run failed event
-            route_node_state.status = RouteNodeState.Status.FAILED
-            route_node_state.failed_reason = "Workflow stopped."
-            yield NodeRunFailedEvent(
-                error="Workflow stopped.",
-                id=node_instance.id,
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_data=node_instance.node_data,
-                route_node_state=route_node_state,
-                parallel_id=parallel_id,
-                parallel_start_node_id=parallel_start_node_id,
-                parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id,
-            )
-            return
-        except Exception as e:
-            logger.exception(f"Node {node_instance.node_data.title} run failed")
-            raise e
-        finally:
-            db.session.close()
+                        elif isinstance(item, RunRetrieverResourceEvent):
+                            yield NodeRunRetrieverResourceEvent(
+                                id=node_instance.id,
+                                node_id=node_instance.node_id,
+                                node_type=node_instance.node_type,
+                                node_data=node_instance.node_data,
+                                retriever_resources=item.retriever_resources,
+                                context=item.context,
+                                route_node_state=route_node_state,
+                                parallel_id=parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                                parent_parallel_id=parent_parallel_id,
+                                parent_parallel_start_node_id=parent_parallel_start_node_id,
+                            )
+            except GenerateTaskStoppedError:
+                # trigger node run failed event
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = "Workflow stopped."
+                yield NodeRunFailedEvent(
+                    error="Workflow stopped.",
+                    id=node_instance.id,
+                    node_id=node_instance.node_id,
+                    node_type=node_instance.node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+                return
+            except Exception as e:
+                logger.exception(f"Node {node_instance.node_data.title} run failed")
+                raise e
+            finally:
+                db.session.close()
 
     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 9271867afffa6e..529fd7be74e9a1 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -106,12 +106,25 @@ def validate_value_type(self) -> "DefaultValue":
         return self
 
 
+class RetryConfig(BaseModel):
+    """node retry config"""
+
+    max_retries: int = 0  # max retry times
+    retry_interval: int = 0  # retry interval in milliseconds
+    retry_enabled: bool = False  # whether retry is enabled
+
+    @property
+    def retry_interval_seconds(self) -> float:
+        return self.retry_interval / 1000
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
     version: str = "1"
+    retry_config: RetryConfig = RetryConfig()
 
     @property
     def default_value_dict(self):
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index ee274c2501812e..b799e7426616e7 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -147,3 +147,12 @@ def should_continue_on_error(self) -> bool:
             bool: if should continue on error
         """
         return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+
+    @property
+    def should_retry(self) -> bool:
+        """judge if should retry
+
+        Returns:
+            bool: if should retry
+        """
+        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py
index 6d8ca6f7018cb8..32fdc048d133c0 100644
--- a/api/core/workflow/nodes/enums.py
+++ b/api/core/workflow/nodes/enums.py
@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
 
 
 CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
+RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py
index 5e3b31e48baa9e..08c47d5e57387b 100644
--- a/api/core/workflow/nodes/event/__init__.py
+++ b/api/core/workflow/nodes/event/__init__.py
@@ -1,4 +1,10 @@
-from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from .event import (
+    ModelInvokeCompletedEvent,
+    RunCompletedEvent,
+    RunRetrieverResourceEvent,
+    RunRetryEvent,
+    RunStreamChunkEvent,
+)
 from .types import NodeEvent
 
 __all__ = [
@@ -6,5 +12,6 @@
     "NodeEvent",
     "RunCompletedEvent",
     "RunRetrieverResourceEvent",
+    "RunRetryEvent",
     "RunStreamChunkEvent",
 ]
diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py
index b7034561bf6713..0dc35e7d7740ef 100644
--- a/api/core/workflow/nodes/event/event.py
+++ b/api/core/workflow/nodes/event/event.py
@@ -1,7 +1,10 @@
+from datetime import datetime
+
 from pydantic import BaseModel, Field
 
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.entities.node_entities import NodeRunResult
+from models.workflow import WorkflowNodeExecutionStatus
 
 
 class RunCompletedEvent(BaseModel):
@@ -26,3 +29,24 @@ class ModelInvokeCompletedEvent(BaseModel):
     text: str
     usage: LLMUsage
     finish_reason: str | None = None
+
+
+class RunRetryEvent(BaseModel):
+    """Node Run Retry event"""
+
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="Retry attempt number")
+    start_at: datetime = Field(..., description="Retry start time")
+
+
+class SingleStepRetryEvent(BaseModel):
+    """Single step retry event"""
+
+    status: str = WorkflowNodeExecutionStatus.RETRY.value
+
+    inputs: dict | None = Field(..., description="input")
+    error: str = Field(..., description="error")
+    outputs: dict = Field(..., description="output")
+    retry_index: int = Field(..., description="Retry attempt number")
+    elapsed_time: float = Field(..., description="elapsed time")
+    execution_metadata: dict | None = Field(..., description="execution metadata")
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 90251c27a89713..92f190091badeb 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -45,6 +45,7 @@ class Executor:
     headers: dict[str, str]
     auth: HttpRequestNodeAuthorization
     timeout: HttpRequestNodeTimeout
+    max_retries: int
 
     boundary: str
 
@@ -54,6 +55,7 @@ def __init__(
         node_data: HttpRequestNodeData,
         timeout: HttpRequestNodeTimeout,
         variable_pool: VariablePool,
+        max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
     ):
         # If authorization API key is present, convert the API key using the variable pool
         if node_data.authorization.type == "api-key":
@@ -73,6 +75,7 @@ def __init__(
         self.files = None
         self.data = None
         self.json = None
+        self.max_retries = max_retries
 
         # init template
         self.variable_pool = variable_pool
@@ -241,6 +244,7 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
             "params": self.params,
             "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
             "follow_redirects": True,
+            "max_retries": self.max_retries,
         }
         # request_args = {k: v for k, v in request_args.items() if v is not None}
         try:
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index a0fc8acaefaaec..171389a34c1c41 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -52,6 +52,11 @@ def get_default_config(cls, filters: dict | None = None) -> dict:
                 "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
             },
         },
+        "retry_config": {
+            "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
+            "retry_interval": 0.5 * (2**2),
+            "retry_enabled": True,
+        },
     }
 
 def _run(self) -> NodeRunResult:
@@ -61,12 +66,13 @@ def _run(self) -> NodeRunResult:
             node_data=self.node_data,
             timeout=self._get_request_timeout(self.node_data),
             variable_pool=self.graph_runtime_state.variable_pool,
+            max_retries=0,
         )
         process_data["request"] = http_executor.to_log()
 
         response = http_executor.invoke()
         files = self.extract_files(url=http_executor.url, response=response)
-        if not response.response.is_success and self.should_continue_on_error:
+        if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 outputs={
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 8390c665561841..7c01ffc2c62e39 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -29,6 +29,7 @@
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_for_list_fields = {
@@ -45,6 +46,7 @@
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }
 
 advanced_chat_workflow_run_pagination_fields = {
@@ -79,6 +81,17 @@
     "exceptions_count": fields.Integer,
 }
 
+retry_event_field = {
+    "error": fields.String,
+    "retry_index": fields.Integer,
+    "inputs": fields.Raw(attribute="inputs"),
+    "elapsed_time": fields.Float,
+    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
+    "status": fields.String,
+    "outputs": fields.Raw(attribute="outputs"),
+}
+
+
 workflow_run_node_execution_fields = {
     "id": fields.String,
     "index": fields.Integer,
@@ -99,6 +112,7 @@
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "finished_at": TimestampField,
+    "retry_events": fields.List(fields.Nested(retry_event_field)),
 }
 
 workflow_run_node_execution_list_fields = {
diff --git a/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py
new file mode 100644
index 00000000000000..eef8224dea7883
--- /dev/null
+++ b/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py
@@ -0,0 +1,33 @@
+"""add retry_index field to node-execution model
+
+Revision ID: 348cb0a93d53
+Revises: cf8f4fc45278
+Create Date: 2024-12-16 01:23:13.093432
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '348cb0a93d53'
+down_revision = 'cf8f4fc45278'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.drop_column('retry_index')
+
+    # ### end Alembic commands ###
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 4ddf2e9082f10c..e933382a841a6f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
     SUCCEEDED = "succeeded"
     FAILED = "failed"
     EXCEPTION = "exception"
+    RETRY = "retry"
 
     @classmethod
     def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
+    retry_index = db.Column(db.Integer, server_default=db.text("0"))
 
     @property
     def created_by_account(self):
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 84768d5af053e4..ead552d6c2e83e 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -15,6 +15,7 @@
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.event.event import SingleStepRetryEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -220,56 +221,99 @@ def run_draft_workflow_node(
 
         # run draft workflow node
         start_at = time.perf_counter()
+        retries = 0
+        max_retries = 0
+        should_retry = True
+        retry_events = []
 
         try:
-            node_instance, generator = WorkflowEntry.single_step_run(
-                workflow=draft_workflow,
-                node_id=node_id,
-                user_inputs=user_inputs,
-                user_id=account.id,
-            )
-            node_instance = cast(BaseNode[BaseNodeData], node_instance)
-            node_run_result: NodeRunResult | None = None
-            for event in generator:
-                if isinstance(event, RunCompletedEvent):
-                    node_run_result = event.run_result
-
-                    # sign output files
-                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
-                    break
-
-            if not node_run_result:
-                raise ValueError("Node run failed with no run result")
-            # single step debug mode error handling return
-            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
-                node_error_args = {
-                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
-                    "error": node_run_result.error,
-                    "inputs": node_run_result.inputs,
-                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
-                }
-                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            **node_instance.node_data.default_value_dict,
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-                else:
-                    node_run_result = NodeRunResult(
-                        **node_error_args,
-                        outputs={
-                            "error_message": node_run_result.error,
-                            "error_type": node_run_result.error_type,
-                        },
-                    )
-            run_succeeded = node_run_result.status in (
-                WorkflowNodeExecutionStatus.SUCCEEDED,
-                WorkflowNodeExecutionStatus.EXCEPTION,
-            )
-            error = node_run_result.error if not run_succeeded else None
+            while retries <= max_retries and should_retry:
+                retry_start_at = time.perf_counter()
+                node_instance, generator = WorkflowEntry.single_step_run(
+                    workflow=draft_workflow,
+                    node_id=node_id,
+                    user_inputs=user_inputs,
+                    user_id=account.id,
+                )
+                node_instance = cast(BaseNode[BaseNodeData], node_instance)
+                max_retries = (
+                    node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
+                )
+                retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+                node_run_result: NodeRunResult | None = None
+                for event in generator:
+                    if isinstance(event, RunCompletedEvent):
+                        node_run_result = event.run_result
+
+                        # sign output files
+                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+                        break
+
+                if not node_run_result:
+                    raise ValueError("Node run failed with no run result")
+                # single step debug mode error handling return
+                if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                    if (
+                        retries == max_retries
+                        and node_instance.node_type == NodeType.HTTP_REQUEST
+                        and node_run_result.outputs
+                        and not node_instance.should_continue_on_error
+                    ):
+                        node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                        should_retry = False
+                    else:
+                        if node_instance.should_retry:
+                            node_run_result.status = WorkflowNodeExecutionStatus.RETRY
+                            retries += 1
+                            node_run_result.retry_index = retries
+                            retry_events.append(
+                                SingleStepRetryEvent(
+                                    inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
+                                    if node_run_result.inputs
+                                    else None,
+                                    error=node_run_result.error,
+                                    outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
+                                    if node_run_result.outputs
+                                    else None,
+                                    retry_index=node_run_result.retry_index,
+                                    elapsed_time=time.perf_counter() - retry_start_at,
+                                    execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
+                                    if node_run_result.metadata
+                                    else None,
+                                )
+                            )
+                            time.sleep(retry_interval)
+                        else:
+                            should_retry = False
+                if node_instance.should_continue_on_error:
+                    node_error_args = {
+                        "status": WorkflowNodeExecutionStatus.EXCEPTION,
+                        "error": node_run_result.error,
+                        "inputs": node_run_result.inputs,
+                        "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+                    }
+                    if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+                        node_run_result = NodeRunResult(
+                            **node_error_args,
+                            outputs={
+                                **node_instance.node_data.default_value_dict,
+                                "error_message": node_run_result.error,
+                                "error_type": node_run_result.error_type,
+                            },
+                        )
+                    else:
+                        node_run_result = NodeRunResult(
+                            **node_error_args,
+                            outputs={
+                                "error_message": node_run_result.error,
+                                "error_type": node_run_result.error_type,
+                            },
+                        )
+                run_succeeded = node_run_result.status in (
+                    WorkflowNodeExecutionStatus.SUCCEEDED,
+                    WorkflowNodeExecutionStatus.EXCEPTION,
+                )
+                error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
             node_instance = e.node_instance
             run_succeeded = False
@@ -318,6 +362,7 @@ def run_draft_workflow_node(
 
         db.session.add(workflow_node_execution)
         db.session.commit()
+        workflow_node_execution.retry_events = retry_events
 
         return workflow_node_execution
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
index ba209e4020afad..2d74be9da9a96c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
@@ -2,7 +2,6 @@
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
-    GraphRunSucceededEvent,
     NodeRunExceptionEvent,
     NodeRunStreamChunkEvent,
 )
@@ -14,7 +13,9 @@
 
 class ContinueOnErrorTestHelper:
     @staticmethod
-    def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
+    def get_code_node(
+        code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
+    ):
         """Helper method to create a code node configuration"""
         node = {
             "id": "node",
@@ -26,6 +27,7 @@ def get_code_node(code: str, error_strategy: str = "fail-branch", default_value:
                 "code_language": "python3",
                 "code": "\n".join([line[4:] for line in code.split("\n")]),
                 "type": "code",
+                **retry_config,
             },
         }
         if default_value:
@@ -34,7 +36,10 @@
 
     @staticmethod
     def get_http_node(
-        error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
+        error_strategy: str = "fail-branch",
+        default_value: dict | None = None,
+        authorization_success: bool = False,
+        retry_config: dict = {},
     ):
         """Helper method to create a http node configuration"""
         authorization = (
@@ -65,6 +70,7 @@
                 "body": None,
                 "type": "http-request",
                 "error_strategy": error_strategy,
+                **retry_config,
             },
         }
         if default_value:
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py
new file mode 100644
index 00000000000000..c232875ce57d3a
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py
@@ -0,0 +1,73 @@
+from core.workflow.graph_engine.entities.event import (
+    GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
+    GraphRunSucceededEvent,
+    NodeRunRetryEvent,
+)
+from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
+
+DEFAULT_VALUE_EDGE = [
+    {
+        "id": "start-source-node-target",
+        "source": "start",
+        "target": "node",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-source-answer-target",
+        "source": "node",
+        "target": "answer",
+        "sourceHandle": "source",
+    },
+]
+
+
+def test_retry_default_value_partial_success():
+    """retry default value node with partial success status"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
"id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", + [{"key": "result", "type": "string", "value": "http node got error response"}], + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert events[-1].outputs == {"answer": "http node got error response"} + assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) + assert len(events) == 11 + + +def test_retry_failed(): + """retry failed with success status""" + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + None, + None, + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert any(isinstance(e, GraphRunFailedEvent) for e in events) + assert len(events) == 8 From 3335fa78fc554d423bd7d2c9ba02d1818b551cab Mon Sep 17 00:00:00 2001 From: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:14:27 +0800 Subject: [PATCH 21/91] fix: node 22 build (#11883) --- web/package.json | 2 +- web/yarn.lock | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/web/package.json b/web/package.json index a1ba2cbd43d9bd..25cf231a95ed63 100644 --- a/web/package.json +++ b/web/package.json @@ -50,7 +50,7 @@ "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.7", - "echarts": "^5.4.1", + "echarts": "^5.5.1", "echarts-for-react": "^3.0.2", "elkjs": "^0.9.3", "emoji-mart": "^5.5.2", diff --git a/web/yarn.lock b/web/yarn.lock index 035b9f3867d3a9..c9590d6b066613 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -5903,13 +5903,13 @@ echarts-for-react@^3.0.2: fast-deep-equal "^3.1.3" size-sensor "^1.0.1" -echarts@^5.4.1: - version "5.4.2" - resolved "https://registry.npmjs.org/echarts/-/echarts-5.4.2.tgz" - integrity sha512-2W3vw3oI2tWJdyAz+b8DuWS0nfXtSDqlDmqgin/lfzbkB01cuMEN66KWBlmur3YMp5nEDEEt5s23pllnAzB4EA== +echarts@^5.5.1: + version "5.5.1" + resolved "https://registry.yarnpkg.com/echarts/-/echarts-5.5.1.tgz#8dc9c68d0c548934bedcb5f633db07ed1dd2101c" + integrity sha512-Fce8upazaAXUVUVsjgV6mBnGuqgO+JNDlcgF79Dksy4+wgGpQB2lmYoO4TSweFg/mZITdpGHomw/cNBJZj1icA== dependencies: tslib "2.3.0" - zrender "5.4.3" + zrender "5.6.0" electron-to-chromium@^1.5.41: version "1.5.52" @@ -13330,10 +13330,10 @@ zod@^3.23.6: resolved "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz" integrity sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g== -zrender@5.4.3: - version "5.4.3" - resolved "https://registry.npmjs.org/zrender/-/zrender-5.4.3.tgz" - integrity sha512-DRUM4ZLnoaT0PBVvGBDO9oWIDBKFdAVieNWxWwK0niYzJCMwGchRk21/hsE+RKkIveH3XHCyvXcJDkgLVvfizQ== +zrender@5.6.0: + version "5.6.0" + resolved "https://registry.yarnpkg.com/zrender/-/zrender-5.6.0.tgz#01325b0bb38332dd5e87a8dbee7336cafc0f4a5b" + integrity 
   dependencies:
     tslib "2.3.0"
From e2cde628bb676a45f5606e9d04bdf5dd9a506571 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 20 Dec 2024 14:19:47 +0800
Subject: [PATCH 22/91] chore: translate i18n files (#11855)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
 web/i18n/de-DE/tools.ts   | 2 ++
 web/i18n/es-ES/tools.ts   | 2 ++
 web/i18n/fa-IR/tools.ts   | 2 ++
 web/i18n/fr-FR/tools.ts   | 2 ++
 web/i18n/hi-IN/tools.ts   | 2 ++
 web/i18n/it-IT/tools.ts   | 2 ++
 web/i18n/ja-JP/tools.ts   | 2 ++
 web/i18n/ko-KR/tools.ts   | 2 ++
 web/i18n/pl-PL/tools.ts   | 2 ++
 web/i18n/pt-BR/tools.ts   | 2 ++
 web/i18n/ro-RO/tools.ts   | 2 ++
 web/i18n/ru-RU/tools.ts   | 2 ++
 web/i18n/sl-SI/tools.ts   | 2 ++
 web/i18n/th-TH/tools.ts   | 2 ++
 web/i18n/tr-TR/tools.ts   | 2 ++
 web/i18n/uk-UA/tools.ts   | 2 ++
 web/i18n/vi-VN/tools.ts   | 2 ++
 web/i18n/zh-Hant/tools.ts | 2 ++
 18 files changed, 36 insertions(+)

diff --git a/web/i18n/de-DE/tools.ts b/web/i18n/de-DE/tools.ts
index 3be01b83504b71..2448b3ed8f4605 100644
--- a/web/i18n/de-DE/tools.ts
+++ b/web/i18n/de-DE/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     emptyTitle: 'Kein Workflow-Tool verfügbar',
     type: 'Art',
     emptyTip: 'Gehen Sie zu "Workflow -> Als Tool veröffentlichen"',
+    emptyTitleCustom: 'Kein benutzerdefiniertes Tool verfügbar',
+    emptyTipCustom: 'Erstellen eines benutzerdefinierten Werkzeugs',
   },
   toolNameUsageTip: 'Name des Tool-Aufrufs für die Argumentation und Aufforderung des Agenten',
   customToolTip: 'Erfahren Sie mehr über benutzerdefinierte Dify-Tools',
diff --git a/web/i18n/es-ES/tools.ts b/web/i18n/es-ES/tools.ts
index 546591f1aabf13..08c9f2026dd961 100644
--- a/web/i18n/es-ES/tools.ts
+++ b/web/i18n/es-ES/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'Administrar en Herramientas',
     emptyTitle: 'No hay herramientas de flujo de trabajo disponibles',
     emptyTip: 'Ir a "Flujo de Trabajo -> Publicar como Herramienta"',
+    emptyTitleCustom: 'No hay herramienta personalizada disponible',
+    emptyTipCustom: 'Crear una herramienta personalizada',
   },
   createTool: {
     title: 'Crear Herramienta Personalizada',
diff --git a/web/i18n/fa-IR/tools.ts b/web/i18n/fa-IR/tools.ts
index 002f55d1d4adbf..60a89d0f32abd7 100644
--- a/web/i18n/fa-IR/tools.ts
+++ b/web/i18n/fa-IR/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'مدیریت در ابزارها',
     emptyTitle: 'هیچ ابزار جریان کاری در دسترس نیست',
     emptyTip: 'به "جریان کاری -> انتشار به عنوان ابزار" بروید',
+    emptyTipCustom: 'ایجاد یک ابزار سفارشی',
+    emptyTitleCustom: 'هیچ ابزار سفارشی در دسترس نیست',
   },
   createTool: {
     title: 'ایجاد ابزار سفارشی',
diff --git a/web/i18n/fr-FR/tools.ts b/web/i18n/fr-FR/tools.ts
index 34c71e776400ef..5a7e47906f698e 100644
--- a/web/i18n/fr-FR/tools.ts
+++ b/web/i18n/fr-FR/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     category: 'catégorie',
     manageInTools: 'Gérer dans Outils',
     emptyTip: 'Allez dans « Flux de travail -> Publier en tant qu’outil »',
+    emptyTitleCustom: 'Aucun outil personnalisé disponible',
+    emptyTipCustom: 'Créer un outil personnalisé',
   },
   openInStudio: 'Ouvrir dans Studio',
   customToolTip: 'En savoir plus sur les outils personnalisés Dify',
diff --git a/web/i18n/hi-IN/tools.ts b/web/i18n/hi-IN/tools.ts
index 6b0cccebadc2e5..2060682931df5b 100644
--- a/web/i18n/hi-IN/tools.ts
+++ b/web/i18n/hi-IN/tools.ts
@@ -32,6 +32,8 @@ const translation = {
     manageInTools: 'उपकरणों में प्रबंधित करें',
     emptyTitle: 'कोई कार्यप्रवाह उपकरण उपलब्ध नहीं',
     emptyTip: 'कार्यप्रवाह -> उपकरण के रूप में प्रकाशित पर जाएं',
+    emptyTipCustom: 'एक कस्टम टूल बनाएं',
+    emptyTitleCustom: 'कोई कस्टम टूल उपलब्ध नहीं है',
   },
   createTool: {
     title: 'कस्टम उपकरण बनाएं',
diff --git a/web/i18n/it-IT/tools.ts b/web/i18n/it-IT/tools.ts
index 00e7cad58c779c..f9512fb20d32c5 100644
--- a/web/i18n/it-IT/tools.ts
+++ b/web/i18n/it-IT/tools.ts
@@ -32,6 +32,8 @@ const translation = {
     manageInTools: 'Gestisci in Strumenti',
     emptyTitle: 'Nessun strumento di flusso di lavoro disponibile',
     emptyTip: 'Vai a `Flusso di lavoro -> Pubblica come Strumento`',
+    emptyTitleCustom: 'Nessun attrezzo personalizzato disponibile',
+    emptyTipCustom: 'Creare uno strumento personalizzato',
   },
   createTool: {
     title: 'Crea Strumento Personalizzato',
diff --git a/web/i18n/ja-JP/tools.ts b/web/i18n/ja-JP/tools.ts
index 12d3634715c450..f52f101f52e3ac 100644
--- a/web/i18n/ja-JP/tools.ts
+++ b/web/i18n/ja-JP/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'ツールリストに移動して管理する',
     emptyTitle: '利用可能なワークフローツールはありません',
     emptyTip: '追加するには、「ワークフロー -> ツールとして公開 」に移動する',
+    emptyTitleCustom: 'カスタムツールはありません',
+    emptyTipCustom: 'カスタムツールの作成',
   },
   createTool: {
     title: 'カスタムツールを作成する',
diff --git a/web/i18n/ko-KR/tools.ts b/web/i18n/ko-KR/tools.ts
index c896a17a4f5716..0b9f45178497e2 100644
--- a/web/i18n/ko-KR/tools.ts
+++ b/web/i18n/ko-KR/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: '도구에서 관리',
     emptyTitle: '사용 가능한 워크플로우 도구 없음',
     emptyTip: '"워크플로우 -> 도구로 등록하기"로 이동',
+    emptyTipCustom: '사용자 지정 도구 만들기',
+    emptyTitleCustom: '사용 가능한 사용자 지정 도구가 없습니다.',
   },
   createTool: {
     title: '커스텀 도구 만들기',
diff --git a/web/i18n/pl-PL/tools.ts b/web/i18n/pl-PL/tools.ts
index f34825b049651d..768883522e83c7 100644
--- a/web/i18n/pl-PL/tools.ts
+++ b/web/i18n/pl-PL/tools.ts
@@ -148,6 +148,8 @@ const translation = {
     add: 'dodawać',
     emptyTitle: 'Brak dostępnego narzędzia do przepływu pracy',
     emptyTip: 'Przejdź do "Przepływ pracy -> Opublikuj jako narzędzie"',
+    emptyTitleCustom: 'Brak dostępnego narzędzia niestandardowego',
+    emptyTipCustom: 'Tworzenie narzędzia niestandardowego',
   },
   openInStudio: 'Otwieranie w Studio',
   customToolTip: 'Dowiedz się więcej o niestandardowych narzędziach Dify',
diff --git a/web/i18n/pt-BR/tools.ts b/web/i18n/pt-BR/tools.ts
index 1b207153286a9f..8af475a98accc8 100644
--- a/web/i18n/pt-BR/tools.ts
+++ b/web/i18n/pt-BR/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     emptyTitle: 'Nenhuma ferramenta de fluxo de trabalho disponível',
     added: 'Adicionado',
     manageInTools: 'Gerenciar em Ferramentas',
+    emptyTitleCustom: 'Nenhuma ferramenta personalizada disponível',
+    emptyTipCustom: 'Criar uma ferramenta personalizada',
   },
   openInStudio: 'Abrir no Studio',
   customToolTip: 'Saiba mais sobre as ferramentas personalizadas da Dify',
diff --git a/web/i18n/ro-RO/tools.ts b/web/i18n/ro-RO/tools.ts
index 165bdb26ede683..baeffb2b66d680 100644
--- a/web/i18n/ro-RO/tools.ts
+++ b/web/i18n/ro-RO/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     type: 'tip',
     emptyTitle: 'Nu este disponibil niciun instrument de flux de lucru',
     emptyTip: 'Accesați "Flux de lucru -> Publicați ca instrument"',
+    emptyTitleCustom: 'Nu este disponibil niciun instrument personalizat',
+    emptyTipCustom: 'Crearea unui instrument personalizat',
   },
   openInStudio: 'Deschide în Studio',
   customToolTip: 'Aflați mai multe despre instrumentele personalizate Dify',
diff --git a/web/i18n/ru-RU/tools.ts b/web/i18n/ru-RU/tools.ts
index e0dfd571b26be3..4749fee16377a3 100644
--- a/web/i18n/ru-RU/tools.ts
+++ b/web/i18n/ru-RU/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'Управлять в инструментах',
     emptyTitle: 'Нет доступных инструментов рабочего процесса',
     emptyTip: 'Перейдите в "Рабочий процесс -> Опубликовать как инструмент"',
+    emptyTitleCustom: 'Нет пользовательского инструмента',
+    emptyTipCustom: 'Создание пользовательского инструмента',
   },
   createTool: {
     title: 'Создать пользовательский инструмент',
diff --git a/web/i18n/sl-SI/tools.ts b/web/i18n/sl-SI/tools.ts
index 57160cfe62309f..63b508a05d4da2 100644
--- a/web/i18n/sl-SI/tools.ts
+++ b/web/i18n/sl-SI/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'Upravljaj v Orodjih',
     emptyTitle: 'Orodje za potek dela ni na voljo',
     emptyTip: 'Pojdite na "Potek dela -> Objavi kot orodje"',
+    emptyTipCustom: 'Ustvarjanje orodja po meri',
+    emptyTitleCustom: 'Orodje po meri ni na voljo',
   },
   createTool: {
     title: 'Ustvari prilagojeno orodje',
diff --git a/web/i18n/th-TH/tools.ts b/web/i18n/th-TH/tools.ts
index a3e12bafd0a367..98272e83f585db 100644
--- a/web/i18n/th-TH/tools.ts
+++ b/web/i18n/th-TH/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'จัดการในเครื่องมือ',
     emptyTitle: 'ไม่มีเครื่องมือเวิร์กโฟลว์',
     emptyTip: 'ไปที่ "เวิร์กโฟลว์ -> เผยแพร่เป็นเครื่องมือ"',
+    emptyTitleCustom: 'ไม่มีเครื่องมือที่กําหนดเอง',
+    emptyTipCustom: 'สร้างเครื่องมือแบบกําหนดเอง',
   },
   createTool: {
     title: 'สร้างเครื่องมือที่กําหนดเอง',
diff --git a/web/i18n/tr-TR/tools.ts b/web/i18n/tr-TR/tools.ts
index 00af8ed7f2ecfb..a579ac82f1f222 100644
--- a/web/i18n/tr-TR/tools.ts
+++ b/web/i18n/tr-TR/tools.ts
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'Araçlarda Yönet',
     emptyTitle: 'Kullanılabilir workflow aracı yok',
     emptyTip: 'Git "Workflow -> Araç olarak Yayınla"',
+    emptyTitleCustom: 'Özel bir araç yok',
+    emptyTipCustom: 'Özel bir araç oluşturun',
   },
   createTool: {
     title: 'Özel Araç Oluştur',
diff --git a/web/i18n/uk-UA/tools.ts b/web/i18n/uk-UA/tools.ts
index 309a450afc1bb5..f84d0d82cc516d 100644
--- a/web/i18n/uk-UA/tools.ts
+++ b/web/i18n/uk-UA/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     manageInTools: 'Керування в інструментах',
     emptyTip: 'Перейдіть до розділу "Робочий процес -> Опублікувати як інструмент"',
     emptyTitle: 'Немає доступного інструменту для роботи з робочими процесами',
+    emptyTitleCustom: 'Немає доступного спеціального інструменту',
+    emptyTipCustom: 'Створення власного інструмента',
   },
   openInStudio: 'Відкрити в Студії',
   customToolTip: 'Дізнайтеся більше про користувацькі інструменти Dify',
diff --git a/web/i18n/vi-VN/tools.ts b/web/i18n/vi-VN/tools.ts
index b03a6ccc98e35b..86c55166f950cf 100644
--- a/web/i18n/vi-VN/tools.ts
+++ b/web/i18n/vi-VN/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     added: 'Thêm',
     emptyTip: 'Đi tới "Quy trình làm việc -> Xuất bản dưới dạng công cụ"',
     emptyTitle: 'Không có sẵn công cụ quy trình làm việc',
+    emptyTitleCustom: 'Không có công cụ tùy chỉnh nào có sẵn',
+    emptyTipCustom: 'Tạo công cụ tùy chỉnh',
   },
   toolNameUsageTip: 'Tên cuộc gọi công cụ để lý luận và nhắc nhở tổng đài viên',
   customToolTip: 'Tìm hiểu thêm về các công cụ tùy chỉnh Dify',
diff --git a/web/i18n/zh-Hant/tools.ts b/web/i18n/zh-Hant/tools.ts
index d45980c0175dcf..40a63eff653ac5 100644
--- a/web/i18n/zh-Hant/tools.ts
+++ b/web/i18n/zh-Hant/tools.ts
@@ -144,6 +144,8 @@ const translation = {
     category: '類別',
     emptyTitle: '沒有可用的工作流程工具',
     emptyTip: '轉到“工作流 - >發佈為工具”',
+    emptyTipCustom: '創建自訂工具',
+    emptyTitleCustom: '沒有可用的自訂工具',
   },
   customToolTip: '瞭解有關 Dify 自訂工具的更多資訊',
   toolNameUsageTip: '用於代理推理和提示的工具調用名稱',
From 52201d95b1185aaf88e9707f6a16002eaec3dbd9 Mon Sep 17 00:00:00 2001
From: Novice <857526207@qq.com>
Date: Fri, 20 Dec 2024 14:40:33 +0800
Subject: [PATCH 23/91] chore: add retry index migration (#11887)

Co-authored-by: Novice Lee
---
 ...e15e_add_retry_index_field_to_node_execution_.py} | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
 rename api/migrations/versions/{2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py => 2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py} (78%)

diff --git a/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
similarity index 78%
rename from api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py
rename to api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
index eef8224dea7883..3254c23c96192d 100644
--- a/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py
+++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
@@ -1,8 +1,8 @@
-"""add retry_index field to node-execution model
+"""add retry_index field to node-execution model
 
-Revision ID: 348cb0a93d53
-Revises: cf8f4fc45278
-Create Date: 2024-12-16 01:23:13.093432
+Revision ID: e1944c35e15e
+Revises: 11b07f66c737
+Create Date: 2024-12-20 06:28:30.287197
 
 """
 from alembic import op
@@ -11,8 +11,8 @@
 
 # revision identifiers, used by Alembic.
-revision = '348cb0a93d53'
-down_revision = 'cf8f4fc45278'
+revision = 'e1944c35e15e'
+down_revision = '11b07f66c737'
 branch_labels = None
 depends_on = None
From 7b03a0316d97f9b315a024e8a645bff1e8757f21 Mon Sep 17 00:00:00 2001
From: yihong
Date: Fri, 20 Dec 2024 14:51:43 +0800
Subject: [PATCH 24/91] fix: better memory usage from 800+ to 500+ (#11796)

Signed-off-by: yihong0618
---
 .../model_providers/vertex_ai/llm/llm.py          | 26 +++++++++++++------
 .../text_embedding/text_embedding.py              | 18 ++++++++++---
 .../jieba/jieba_keyword_table_handler.py          | 15 ++++++-----
 .../rag/datasource/vdb/oracle/oraclevector.py     |  6 +++--
 .../workflow/nodes/document_extractor/node.py     | 17 +++++++-----
 5 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
index 934195cc3d6fa2..c50e0f794616b3 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -4,11 +4,10 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)
 
 
@@ -102,6 +102,8 @@ def _generate_anthropic(
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
 
         return text.rstrip()
 
-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools
 
         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ def _generate(
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
 
@@ -522,7 +530,7 @@ def _generate(
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@ def _handle_generate_response(
         return result
 
     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
 
         return message_text
 
-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
         """
         Format a single message into glm.Content for Google API
 
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
+        import vertexai.generative_models as glm
+
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])
 
diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
index eb54941e086752..b8b0e5f15acb44 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
@@ -2,12 +2,9 @@
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
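Every hunk in this patch applies the same deferred-import idiom: the heavy SDK import leaves the module top level, TYPE_CHECKING keeps annotations resolvable for static checkers, and the runtime import is paid only on first call. A minimal standalone sketch of the pattern, assuming vertexai is installed (the default model name and the from_pretrained loader are illustrative assumptions, not taken from this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen only by type checkers; costs nothing at interpreter start-up
    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel


def load_embedding_model(model_name: str = "text-embedding-004") -> "VertexTextEmbeddingModel":
    # the vertexai import is deferred until embeddings are actually requested,
    # so merely importing this module stays cheap
    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

    return VertexTextEmbeddingModel.from_pretrained(model_name)

The quoted "VertexTextEmbeddingModel" in the return annotation is what makes the TYPE_CHECKING trick safe: it is never evaluated at runtime.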
@@ -24,6 +21,11 @@
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
 
+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+
 
 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ def _invoke(
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 4b1ade8e3fa095..ec809cf325306e 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,18 +1,19 @@
 import re
 from typing import Optional
 
-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-
 
 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 71c58c9d0c37d5..74608f1e1a3b05 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -6,10 +6,8 @@
 from typing import Any
 
 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -202,6 +200,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 59afe7ac87a855..95d0ea3aab54d6 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -8,12 +8,6 @@
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx
 
 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
 
 
 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
 
 
 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
 
 
 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
 
 
 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
 
 
 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)
From dacd457478b7646b20c41988a62b4d5c1801c21f Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Fri, 20 Dec 2024 14:52:20 +0800
Subject: [PATCH 25/91] feat: add workflow parallel depth limit configuration
 (#11460)

Signed-off-by: -LAN-
Co-authored-by: zxhlyh
---
 api/.env.example                                  |  1 +
 api/configs/feature/__init__.py                   |  5 +++++
 api/controllers/console/app/workflow.py           | 15 +++++++++++++++
 api/core/workflow/graph_engine/entities/graph.py  |  5 ++++-
 api/tests/unit_tests/configs/test_dify_config.py  |  2 ++
 docker/.env.example                               |  1 +
 docker/docker-compose.yaml                        |  3 ++-
 web/app/components/workflow/hooks/use-workflow.ts |  9 ++++++---
 web/service/use-workflow.ts                       | 12 ++++++++++++
 web/types/workflow.ts                             |  4 ++++
 10 files changed, 52 insertions(+), 5 deletions(-)
 create mode 100644 web/service/use-workflow.ts

diff --git a/api/.env.example b/api/.env.example
index 5652b9c346eed5..071a200e680278 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -399,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800
 
 # App configuration
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 623c7cf1e109f9..73f8a95989baaf 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
         default=5,
     )
 
+    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
+        description="Maximum allowed depth for nested parallel executions",
+        default=3,
+    )
+
     MAX_VARIABLE_SIZE: PositiveInt = Field(
         description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
         default=200 * 1024,
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index c85d554069c4ef..f228c3ec4a0e07 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -6,6 +6,7 @@
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model
@@ -426,7 +427,21 @@ def post(self, app_model: App):
         }
 
 
+class WorkflowConfigApi(Resource):
+    """Resource for workflow configuration."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App):
+        return {
+            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+        }
+
+
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 4f7bc60e26b5e2..800dd136afb57f 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -4,6 +4,7 @@
 
 from pydantic import BaseModel, Field
 
+from configs import dify_config
 from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@@ -170,7 +171,9 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non
         for parallel in parallel_mapping.values():
             if parallel.parent_parallel_id:
                 cls._check_exceed_parallel_limit(
-                    parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
+                    parallel_mapping=parallel_mapping,
+                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+                    parent_parallel_id=parallel.parent_parallel_id,
                 )
 
         # init answer stream generate routes
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index 385eb08c3681ff..efa9ea89794b92 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -59,6 +59,8 @@ def test_dify_config(example_env_file):
     # annotated field with configured value
     assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
 
+    assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
+
 
 # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
 # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
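What the new level_limit bounds, as a rough standalone sketch: the parallel_mapping shape and the parent_parallel_id field mirror the graph.py hunk above, but the walk and the exception type here are assumptions; the real check lives in Graph._check_exceed_parallel_limit.

def check_exceed_parallel_limit(parallel_mapping: dict, level_limit: int, parent_parallel_id: str, depth: int = 1) -> None:
    # climb the chain of parent parallels; each parent adds one nesting level
    parent = parallel_mapping.get(parent_parallel_id)
    if parent is None:
        return
    depth += 1
    if depth > level_limit:
        raise ValueError(f"parallel nesting exceeds the configured depth limit of {level_limit}")
    if parent.parent_parallel_id:
        check_exceed_parallel_limit(parallel_mapping, level_limit, parent.parent_parallel_id, depth)

With the default WORKFLOW_PARALLEL_DEPTH_LIMIT of 3, a parallel nested inside two enclosing parallels passes, and one level deeper raises.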
diff --git a/docker/.env.example b/docker/.env.example
index e5bddf7ae13dc8..e8ec246ae20b4f 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -699,6 +699,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
 MAX_VARIABLE_SIZE=204800
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 WORKFLOW_FILE_UPLOAD_LIMIT=10
 
 # HTTP request node in workflow configuration
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 75f53f5ec3d022..d0738d93054fcb 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -18,7 +18,6 @@ x-shared-env: &shared-api-worker-env
   LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"}
   LOG_TZ: ${LOG_TZ:-UTC}
   DEBUG: ${DEBUG:-false}
-  SENTRY_DSN: ${SENTRY_DSN:-}
   FLASK_DEBUG: ${FLASK_DEBUG:-false}
   SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
   INIT_PASSWORD: ${INIT_PASSWORD:-}
@@ -260,6 +259,7 @@ x-shared-env: &shared-api-worker-env
   UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
   UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
   UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}
+  SENTRY_DSN: ${SENTRY_DSN:-}
   API_SENTRY_DSN: ${API_SENTRY_DSN:-}
   API_SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
   API_SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
@@ -299,6 +299,7 @@ x-shared-env: &shared-api-worker-env
   WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200}
   WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5}
   MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
+  WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
   WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
   HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
   HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts
index 3c27b5c91bba21..2eafb4ad40419e 100644
--- a/web/app/components/workflow/hooks/use-workflow.ts
+++ b/web/app/components/workflow/hooks/use-workflow.ts
@@ -57,6 +57,7 @@
 import I18n from '@/context/i18n'
 import { CollectionType } from '@/app/components/tools/types'
 import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants'
+import { useWorkflowConfig } from '@/service/use-workflow'
 
 export const useIsChatMode = () => {
   const appDetail = useAppStore(s => s.appDetail)
@@ -69,7 +70,9 @@ export const useWorkflow = () => {
   const { locale } = useContext(I18n)
   const store = useStoreApi()
   const workflowStore = useWorkflowStore()
+  const appId = useStore(s => s.appId)
   const nodesExtraData = useNodesExtraData()
+  const { data: workflowConfig } = useWorkflowConfig(appId)
   const setPanelWidth = useCallback((width: number) => {
     localStorage.setItem('workflow-node-panel-width', `${width}`)
     workflowStore.setState({ panelWidth: width })
@@ -336,15 +339,15 @@ export const useWorkflow = () => {
 
     for (let i = 0; i < parallelList.length; i++) {
       const parallel = parallelList[i]
 
-      if (parallel.depth > PARALLEL_DEPTH_LIMIT) {
+      if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) {
         const { setShowTips } = workflowStore.getState()
-        setShowTips(t('workflow.common.parallelTip.depthLimit', { num: PARALLEL_DEPTH_LIMIT }))
+        setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) }))
         return false
       }
     }
 
     return true
return true - }, [t, workflowStore]) + }, [t, workflowStore, workflowConfig?.parallel_depth_limit]) const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { diff --git a/web/service/use-workflow.ts b/web/service/use-workflow.ts new file mode 100644 index 00000000000000..948a114b042f84 --- /dev/null +++ b/web/service/use-workflow.ts @@ -0,0 +1,12 @@ +import { useQuery } from '@tanstack/react-query' +import { get } from './base' +import type { WorkflowConfigResponse } from '@/types/workflow' + +const NAME_SPACE = 'workflow' + +export const useWorkflowConfig = (appId: string) => { + return useQuery({ + queryKey: [NAME_SPACE, 'config', appId], + queryFn: () => get(`/apps/${appId}/workflows/draft/config`), + }) +} diff --git a/web/types/workflow.ts b/web/types/workflow.ts index a5db7e635dfa5b..38f0bb5a409fd9 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -333,3 +333,7 @@ export type ConversationVariableResponse = { } export type IterationDurationMap = Record + +export type WorkflowConfigResponse = { + parallel_depth_limit: number +} From 0c0120ef27f306463ae7dc1596a235f812841cf0 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Fri, 20 Dec 2024 15:44:37 +0800 Subject: [PATCH 26/91] Feat/workflow retry (#11885) --- .../chat/chat/answer/workflow-process.tsx | 7 ++ web/app/components/base/input/index.tsx | 9 ++ web/app/components/workflow/constants.ts | 2 + .../workflow/hooks/use-workflow-run.ts | 42 ++++++- .../components/before-run-form/index.tsx | 107 ++++++++++------ .../error-handle/error-handle-on-panel.tsx | 2 - .../nodes/_base/components/retry/hooks.ts | 41 ++++++ .../_base/components/retry/retry-on-node.tsx | 88 +++++++++++++ .../_base/components/retry/retry-on-panel.tsx | 117 ++++++++++++++++++ .../_base/components/retry/style.module.css | 5 + .../nodes/_base/components/retry/types.ts | 5 + .../nodes/_base/components/retry/utils.ts | 0 .../components/workflow/nodes/_base/node.tsx | 14 ++- .../components/workflow/nodes/_base/panel.tsx | 12 ++ .../components/workflow/nodes/http/default.ts | 10 +- .../components/workflow/nodes/http/panel.tsx | 14 ++- .../components/workflow/nodes/llm/panel.tsx | 10 +- .../components/workflow/nodes/tool/panel.tsx | 13 +- .../workflow/panel/debug-and-preview/hooks.ts | 26 ++++ .../workflow/panel/workflow-preview.tsx | 23 +++- web/app/components/workflow/run/index.tsx | 34 ++++- web/app/components/workflow/run/node.tsx | 24 ++++ .../components/workflow/run/result-panel.tsx | 28 +++++ .../workflow/run/retry-result-panel.tsx | 46 +++++++ .../components/workflow/run/tracing-panel.tsx | 4 + web/app/components/workflow/types.ts | 4 + web/app/components/workflow/utils.ts | 14 +++ web/i18n/en-US/workflow.ts | 14 +++ web/i18n/zh-Hans/workflow.ts | 14 +++ web/service/base.ts | 9 +- web/types/workflow.ts | 3 + 31 files changed, 690 insertions(+), 51 deletions(-) create mode 100644 web/app/components/workflow/nodes/_base/components/retry/hooks.ts create mode 100644 web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx create mode 100644 web/app/components/workflow/nodes/_base/components/retry/style.module.css create mode 100644 web/app/components/workflow/nodes/_base/components/retry/types.ts create mode 100644 web/app/components/workflow/nodes/_base/components/retry/utils.ts create mode 100644 web/app/components/workflow/run/retry-result-panel.tsx diff --git 
a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 4a09e27d9818f6..bb9abdb6fcf9b9 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -64,6 +64,12 @@ const WorkflowProcessItem = ({ setShowMessageLogModal(true) }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + const showRetryDetail = useCallback(() => { + setCurrentLogItem(item) + setCurrentLogModalActiveTab('TRACING') + setShowMessageLogModal(true) + }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + return (
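Note on the workflow-process.tsx hunk above: the new showRetryDetail callback has the same body and dependency list as the callback directly above it, which also sets the current log item and opens the message log modal. If the two are meant to stay in lockstep, a shared helper keeps them from drifting. A minimal sketch only, reusing the setter names visible in this hunk (the parameter types are assumptions; in the component they come from the app store):

import { useCallback } from 'react'

// Both entry points set the current log item, force the TRACING tab,
// and open the message log modal, so they can share one callback.
function useOpenTracingModal(
  item: unknown,
  setCurrentLogItem: (item: unknown) => void,
  setCurrentLogModalActiveTab: (tab: string) => void,
  setShowMessageLogModal: (show: boolean) => void,
) {
  return useCallback(() => {
    setCurrentLogItem(item)
    setCurrentLogModalActiveTab('TRACING')
    setShowMessageLogModal(true)
  }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
}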
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index bf8efdb65ab83d..044fc278583fc0 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -28,6 +28,7 @@ export type InputProps = { destructive?: boolean wrapperClassName?: string styleCss?: CSSProperties + unit?: string } & React.InputHTMLAttributes & VariantProps const Input = ({ @@ -43,6 +44,7 @@ const Input = ({ value, placeholder, onChange, + unit, ...props }: InputProps) => { const { t } = useTranslation() @@ -80,6 +82,13 @@ const Input = ({ {destructive && ( )} + { + unit && ( +
+ {unit} +
+ ) + }
) } diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index ffa14b347b8207..d04163b853b7b3 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -506,3 +506,5 @@ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' +export const DEFAULT_RETRY_MAX = 3 +export const DEFAULT_RETRY_INTERVAL = 100 diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index a01b2d31549961..822aa490dbe4ec 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -28,6 +28,7 @@ import { getFilesInLogs, } from '@/app/components/base/file-uploader/utils' import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import type { NodeTracing } from '@/types/workflow' export const useWorkflowRun = () => { const store = useStoreApi() @@ -114,6 +115,7 @@ export const useWorkflowRun = () => { onIterationStart, onIterationNext, onIterationFinish, + onNodeRetry, onError, ...restCallback } = callback || {} @@ -440,10 +442,13 @@ export const useWorkflowRun = () => { }) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { + ...data, ...(draft.tracing[currentIndex].extras ? { extras: draft.tracing[currentIndex].extras } : {}), - ...data, + ...(draft.tracing[currentIndex].retryDetail + ? { retryDetail: draft.tracing[currentIndex].retryDetail } + : {}), } as any } })) @@ -616,6 +621,41 @@ export const useWorkflowRun = () => { if (onIterationFinish) onIterationFinish(params) }, + onNodeRetry: (params) => { + const { data } = params + const { + workflowRunningData, + setWorkflowRunningData, + } = workflowStore.getState() + const { + getNodes, + setNodes, + } = store.getState() + + const nodes = getNodes() + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id) + + if (currentRetryNodeIndex > -1) { + const currentRetryNode = tracing[currentRetryNodeIndex] + if (currentRetryNode.retryDetail) + draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing) + + else + draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing] + } + })) + const newNodes = produce(nodes, (draft) => { + const currentNode = draft.find(node => node.id === data.node_id)! 
+ + currentNode.data._retryIndex = data.retry_index + }) + setNodes(newNodes) + + if (onNodeRetry) + onNodeRetry(params) + }, onParallelBranchStarted: (params) => { // console.log(params, 'parallel start') }, diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index 92a4deb51381b6..d7e2a953dae413 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -17,17 +17,25 @@ import ResultPanel from '@/app/components/workflow/run/result-panel' import Toast from '@/app/components/base/toast' import { TransferMethod } from '@/types/app' import { getProcessedFiles } from '@/app/components/base/file-uploader/utils' +import type { NodeTracing } from '@/types/workflow' +import RetryResultPanel from '@/app/components/workflow/run/retry-result-panel' +import type { BlockEnum } from '@/app/components/workflow/types' +import type { Emoji } from '@/app/components/tools/types' const i18nPrefix = 'workflow.singleRun' type BeforeRunFormProps = { nodeName: string + nodeType?: BlockEnum + toolIcon?: string | Emoji onHide: () => void onRun: (submitData: Record) => void onStop: () => void runningStatus: NodeRunningStatus result?: JSX.Element forms: FormProps[] + retryDetails?: NodeTracing[] + onRetryDetailBack?: any } function formatValue(value: string | any, type: InputVarType) { @@ -50,12 +58,16 @@ function formatValue(value: string | any, type: InputVarType) { } const BeforeRunForm: FC = ({ nodeName, + nodeType, + toolIcon, onHide, onRun, onStop, runningStatus, result, forms, + retryDetails, + onRetryDetailBack = () => { }, }) => { const { t } = useTranslation() @@ -122,48 +134,69 @@ const BeforeRunForm: FC = ({
{t(`${i18nPrefix}.testRun`)} {nodeName}
-
+
{ + onHide() + }}>
- -
-
- {forms.map((form, index) => ( -
-
- {index < forms.length - 1 && } + { + retryDetails?.length && ( +
+ ({ + ...item, + title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`, + node_type: nodeType!, + extras: { + icon: toolIcon!, + }, + }))} + onBack={onRetryDetailBack} + /> +
+ ) + } + { + !retryDetails?.length && ( +
+
+ {forms.map((form, index) => ( +
+ + {index < forms.length - 1 && } +
+ ))}
- ))} -
- -
- {isRunning && ( -
- +
+ {isRunning && ( +
+ +
+ )} +
- )} - -
- {isRunning && ( - - )} - {isFinished && ( - <> - {result} - - )} -
+ {isRunning && ( + + )} + {isFinished && ( + <> + {result} + + )} +
+ ) + }
) diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index f11f8bd5fb0784..89412cabb33a0e 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -14,7 +14,6 @@ import type { CommonNodeType, Node, } from '@/app/components/workflow/types' -import Split from '@/app/components/workflow/nodes/_base/components/split' import Tooltip from '@/app/components/base/tooltip' type ErrorHandleProps = Pick @@ -45,7 +44,6 @@ const ErrorHandle = ({ return ( <> -
{ + const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate() + + const handleRetryConfigChange = useCallback((value?: WorkflowRetryConfig) => { + handleNodeDataUpdateWithSyncDraft({ + id, + data: { + retry_config: value, + }, + }) + }, [id, handleNodeDataUpdateWithSyncDraft]) + + return { + handleRetryConfigChange, + } +} + +export const useRetryDetailShowInSingleRun = () => { + const [retryDetails, setRetryDetails] = useState() + + const handleRetryDetailsChange = useCallback((details: NodeTracing[] | undefined) => { + setRetryDetails(details) + }, []) + + return { + retryDetails, + handleRetryDetailsChange, + } +} diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx new file mode 100644 index 00000000000000..f5d2f08ac8507f --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx @@ -0,0 +1,88 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAlertFill, + RiCheckboxCircleFill, + RiLoader2Line, +} from '@remixicon/react' +import type { Node } from '@/app/components/workflow/types' +import { NodeRunningStatus } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' + +type RetryOnNodeProps = Pick +const RetryOnNode = ({ + data, +}: RetryOnNodeProps) => { + const { t } = useTranslation() + const { retry_config } = data + const showSelectedBorder = data.selected || data._isBundled || data._isEntering + const { + isRunning, + isSuccessful, + isException, + isFailed, + } = useMemo(() => { + return { + isRunning: data._runningStatus === NodeRunningStatus.Running && !showSelectedBorder, + isSuccessful: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder, + isFailed: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder, + isException: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder, + } + }, [data._runningStatus, showSelectedBorder]) + const showDefault = !isRunning && !isSuccessful && !isException && !isFailed + + if (!retry_config) + return null + + return ( +
+
+
+ { + showDefault && ( + t('workflow.nodes.common.retry.retryTimes', { times: retry_config.max_retries }) + ) + } + { + isRunning && ( + <> + + {t('workflow.nodes.common.retry.retrying')} + + ) + } + { + isSuccessful && ( + <> + + {t('workflow.nodes.common.retry.retrySuccessful')} + + ) + } + { + (isFailed || isException) && ( + <> + + {t('workflow.nodes.common.retry.retryFailed')} + + ) + } +
+ { + !showDefault && ( +
+ {data._retryIndex}/{data.retry_config?.max_retries} +
+ ) + } +
+
+ ) +} + +export default RetryOnNode diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx new file mode 100644 index 00000000000000..dc877a632c5d15 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx @@ -0,0 +1,117 @@ +import { useTranslation } from 'react-i18next' +import { useRetryConfig } from './hooks' +import s from './style.module.css' +import Switch from '@/app/components/base/switch' +import Slider from '@/app/components/base/slider' +import Input from '@/app/components/base/input' +import type { + Node, +} from '@/app/components/workflow/types' +import Split from '@/app/components/workflow/nodes/_base/components/split' + +type RetryOnPanelProps = Pick +const RetryOnPanel = ({ + id, + data, +}: RetryOnPanelProps) => { + const { t } = useTranslation() + const { handleRetryConfigChange } = useRetryConfig(id) + const { retry_config } = data + + const handleRetryEnabledChange = (value: boolean) => { + handleRetryConfigChange({ + retry_enabled: value, + max_retries: retry_config?.max_retries || 3, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleMaxRetriesChange = (value: number) => { + if (value > 10) + value = 10 + else if (value < 1) + value = 1 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: value, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleRetryIntervalChange = (value: number) => { + if (value > 5000) + value = 5000 + else if (value < 100) + value = 100 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: retry_config?.max_retries || 3, + retry_interval: value, + }) + } + + return ( + <> +
+
+
+
{t('workflow.nodes.common.retry.retryOnFailure')}
+
+ handleRetryEnabledChange(v)} + /> +
+ { + retry_config?.retry_enabled && ( +
+
+
{t('workflow.nodes.common.retry.maxRetries')}
+ + handleMaxRetriesChange(e.target.value as any)} + min={1} + max={10} + unit={t('workflow.nodes.common.retry.times') || ''} + className={s.input} + /> +
+
+
{t('workflow.nodes.common.retry.retryInterval')}
+ + handleRetryIntervalChange(e.target.value as any)} + min={100} + max={5000} + unit={t('workflow.nodes.common.retry.ms') || ''} + className={s.input} + /> +
+
+ ) + } +
+ + + ) +} + +export default RetryOnPanel diff --git a/web/app/components/workflow/nodes/_base/components/retry/style.module.css b/web/app/components/workflow/nodes/_base/components/retry/style.module.css new file mode 100644 index 00000000000000..2ce8717af852ec --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/style.module.css @@ -0,0 +1,5 @@ +.input::-webkit-inner-spin-button, +.input::-webkit-outer-spin-button { + -webkit-appearance: none; + margin: 0; +} \ No newline at end of file diff --git a/web/app/components/workflow/nodes/_base/components/retry/types.ts b/web/app/components/workflow/nodes/_base/components/retry/types.ts new file mode 100644 index 00000000000000..bb5f593fd545ae --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/types.ts @@ -0,0 +1,5 @@ +export type WorkflowRetryConfig = { + max_retries: number + retry_interval: number + retry_enabled: boolean +} diff --git a/web/app/components/workflow/nodes/_base/components/retry/utils.ts b/web/app/components/workflow/nodes/_base/components/retry/utils.ts new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index f2da2da35a4fe9..4807fa3b2b7457 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,7 +25,10 @@ import { useNodesReadOnly, useToolIcon, } from '../../hooks' -import { hasErrorHandleNode } from '../../utils' +import { + hasErrorHandleNode, + hasRetryNode, +} from '../../utils' import { useNodeIterationInteractions } from '../iteration/use-interactions' import type { IterationNodeType } from '../iteration/types' import { @@ -35,6 +38,7 @@ import { import NodeResizer from './components/node-resizer' import NodeControl from './components/node-control' import ErrorHandleOnNode from './components/error-handle/error-handle-on-node' +import RetryOnNode from './components/retry/retry-on-node' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' @@ -237,6 +241,14 @@ const BaseNode: FC = ({
) } + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = ({
{cloneElement(children, { id, data })}
+ + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = { defaultValue: { @@ -24,6 +27,11 @@ const nodeDefault: NodeDefault = { max_read_timeout: 0, max_write_timeout: 0, }, + retry_config: { + retry_enabled: true, + max_retries: 3, + retry_interval: 100, + }, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/http/panel.tsx b/web/app/components/workflow/nodes/http/panel.tsx index 5c613aa0f35a41..91b3a6140d8d16 100644 --- a/web/app/components/workflow/nodes/http/panel.tsx +++ b/web/app/components/workflow/nodes/http/panel.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React from 'react' +import { memo } from 'react' import { useTranslation } from 'react-i18next' import useConfig from './use-config' import ApiInput from './components/api-input' @@ -18,6 +18,7 @@ import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files' import type { NodePanelProps } from '@/app/components/workflow/types' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.http' @@ -60,6 +61,10 @@ const Panel: FC> = ({ hideCurlPanel, handleCurlImport, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() // To prevent prompt editor in body not update data. if (!isDataReady) return null @@ -181,6 +186,7 @@ const Panel: FC> = ({ {isShowSingleRun && ( > = ({ runningStatus={runningStatus} onRun={handleRun} onStop={handleStop} - result={} + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )} {(isShowCurlPanel && !readOnly) && ( @@ -207,4 +215,4 @@ const Panel: FC> = ({ ) } -export default React.memo(Panel) +export default memo(Panel) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 21ef6395b13558..60f68d93e21270 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -19,6 +19,7 @@ import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/c import ResultPanel from '@/app/components/workflow/run/result-panel' import Tooltip from '@/app/components/base/tooltip' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.llm' @@ -69,6 +70,10 @@ const Panel: FC> = ({ runResult, filterJinjia2InputVar, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() const model = inputs.model @@ -282,12 +287,15 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )} diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 49e645faa48dfe..d0d4c3a83946a4 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -14,6 +14,8 @@ import Loading from '@/app/components/base/loading' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import 
OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' +import { useToolIcon } from '@/app/components/workflow/hooks' const i18nPrefix = 'workflow.nodes.tool' @@ -48,6 +50,11 @@ const Panel: FC> = ({ handleStop, runResult, } = useConfig(id, data) + const toolIcon = useToolIcon(data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() if (isLoading) { return
@@ -143,12 +150,16 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )}
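With the tool panel above, the same single-run retry wiring now appears in three places (HTTP, LLM, and Tool panels): each pulls retryDetails and handleRetryDetailsChange out of useRetryDetailShowInSingleRun() and threads them into BeforeRunForm as retryDetails / onRetryDetailBack, so the single-run modal can flip between the input form and the retry trace list. One detail worth keeping an eye on: BeforeRunForm branches on retryDetails?.length && (...) and !retryDetails?.length && (...); if a defined-but-empty array ever reaches it, the first expression evaluates to the number 0, which React renders as a literal "0". A tiny type guard sidesteps that JSX pitfall (a sketch, not part of the patch):

// Narrowing guard: true only for a non-empty array, and never yields a
// renderable 0 when used inside JSX.
const hasItems = <T,>(list?: T[]): list is T[] => (list?.length ?? 0) > 0

// usage sketch: {hasItems(retryDetails) ? retryTraceView : formView}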
diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 5d932a1ba21608..ebd5e7a99df7df 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -27,6 +27,7 @@ import { getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' import type { FileEntity } from '@/app/components/base/file-uploader/types' +import type { NodeTracing } from '@/types/workflow' type GetAbortController = (abortController: AbortController) => void type SendCallback = { @@ -381,6 +382,28 @@ export const useChat = ( } })) }, + onNodeRetry: ({ data }) => { + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + if (responseItem.workflowProcess!.tracing[currentIndex].retryDetail) + responseItem.workflowProcess!.tracing[currentIndex].retryDetail?.push(data as NodeTracing) + else + responseItem.workflowProcess!.tracing[currentIndex].retryDetail = [data as NodeTracing] + + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) + }, onNodeFinished: ({ data }) => { if (data.iteration_id) return @@ -394,6 +417,9 @@ export const useChat = ( ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } : {}), + ...(responseItem.workflowProcess!.tracing[currentIndex]?.retryDetail + ? 
{ retryDetail: responseItem.workflowProcess!.tracing[currentIndex].retryDetail } + : {}), ...data, } as any handleUpdateChatList(produce(chatListRef.current, (draft) => { diff --git a/web/app/components/workflow/panel/workflow-preview.tsx b/web/app/components/workflow/panel/workflow-preview.tsx index 2139ebd33881e5..210a95f1f80ee6 100644 --- a/web/app/components/workflow/panel/workflow-preview.tsx +++ b/web/app/components/workflow/panel/workflow-preview.tsx @@ -25,6 +25,7 @@ import { import { SimpleBtn } from '../../app/text-generate/item' import Toast from '../../base/toast' import IterationResultPanel from '../run/iteration-result-panel' +import RetryResultPanel from '../run/retry-result-panel' import InputsPanel from './inputs-panel' import cn from '@/utils/classnames' import Loading from '@/app/components/base/loading' @@ -53,11 +54,16 @@ const WorkflowPreview = () => { }, [workflowRunningData]) const [iterationRunResult, setIterationRunResult] = useState([]) + const [retryRunResult, setRetryRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => { setIterDurationMap(iterationDurationMap) @@ -65,6 +71,11 @@ const WorkflowPreview = () => { doShowIterationDetail() }, [doShowIterationDetail]) + const handleRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail]) + if (isShowIterationDetail) { return (
{
)} - {currentTab === 'TRACING' && ( + {currentTab === 'TRACING' && !isShowRetryDetail && ( )} {currentTab === 'TRACING' && !workflowRunningData?.tracing?.length && ( @@ -213,7 +225,14 @@ const WorkflowPreview = () => { )} - + { + currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + } )} diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 2bf705f4ced683..520c59bf4c682d 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -9,6 +9,7 @@ import OutputPanel from './output-panel' import ResultPanel from './result-panel' import TracingPanel from './tracing-panel' import IterationResultPanel from './iteration-result-panel' +import RetryResultPanel from './retry-result-panel' import cn from '@/utils/classnames' import { ToastContext } from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' @@ -107,6 +108,18 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const processNonIterationNode = (item: NodeTracing) => { const { execution_metadata } = item if (!execution_metadata?.iteration_id) { + if (item.status === 'retry') { + const retryNode = result.find(node => node.node_id === item.node_id) + + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + + return + } result.push(item) return } @@ -181,10 +194,15 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const [iterationRunResult, setIterationRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) + const [retryRunResult, setRetryRunResult] = useState([]) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => { setIterationRunResult(detail) @@ -192,6 +210,11 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe setIterDurationMap(iterDurationMap) }, [doShowIterationDetail, setIterationRunResult, setIterDurationMap]) + const handleShowRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail, setRetryRunResult]) + if (isShowIterationDetail) { return (
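The processNonIterationNode change above is the core of the run-log view: trace entries with status === 'retry' are folded into the retryDetail of the node they belong to instead of appearing as top-level rows. A standalone sketch of that folding, with NodeTracing reduced to just the fields this logic touches:

type Trace = { node_id: string; status: string; retryDetail?: Trace[] }

// Fold retry attempts into their parent node, preserving node order.
// Mirroring the hunk above, a retry event whose node has not been seen
// yet is dropped rather than shown on its own.
function foldRetries(list: Trace[]): Trace[] {
  const result: Trace[] = []
  for (const item of list) {
    if (item.status === 'retry') {
      const parent = result.find(node => node.node_id === item.node_id)
      if (parent)
        (parent.retryDetail ??= []).push(item)
      continue
    }
    result.push(item)
  }
  return result
}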
@@ -261,13 +284,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe exceptionCounts={runDetail.exceptions_count} /> )} - {!loading && currentTab === 'TRACING' && ( + {!loading && currentTab === 'TRACING' && !isShowRetryDetail && ( )} + { + !loading && currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + }
) diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index d1a02ecfe007a7..bb07bd1e8c61b8 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -8,6 +8,7 @@ import { RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, + RiRestartFill, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' @@ -20,6 +21,7 @@ import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import type { IterationDurationMap, NodeTracing } from '@/types/workflow' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import { hasRetryNode } from '@/app/components/workflow/utils' type Props = { className?: string @@ -28,8 +30,10 @@ type Props = { hideInfo?: boolean hideProcessDetail?: boolean onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void notShowIterationNav?: boolean justShowIterationNavArrow?: boolean + justShowRetryNavArrow?: boolean } const NodePanel: FC = ({ @@ -39,6 +43,7 @@ const NodePanel: FC = ({ hideInfo = false, hideProcessDetail, onShowIterationDetail, + onShowRetryDetail, notShowIterationNav, justShowIterationNavArrow, }) => { @@ -88,11 +93,17 @@ const NodePanel: FC = ({ }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration + const isRetryNode = hasRetryNode(nodeInfo.node_type) && nodeInfo.retryDetail const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {}) } + const handleOnShowRetryDetail = (e: React.MouseEvent) => { + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onShowRetryDetail?.(nodeInfo.retryDetail || []) + } return (
@@ -169,6 +180,19 @@ const NodePanel: FC = ({
)} + {isRetryNode && ( + + )}
{(nodeInfo.status === 'stopped') && ( diff --git a/web/app/components/workflow/run/result-panel.tsx b/web/app/components/workflow/run/result-panel.tsx index a688693e4fe9df..bbe740ad48ecfe 100644 --- a/web/app/components/workflow/run/result-panel.tsx +++ b/web/app/components/workflow/run/result-panel.tsx @@ -1,11 +1,17 @@ 'use client' import type { FC } from 'react' import { useTranslation } from 'react-i18next' +import { + RiArrowRightSLine, + RiRestartFill, +} from '@remixicon/react' import StatusPanel from './status' import MetaData from './meta' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import type { NodeTracing } from '@/types/workflow' +import Button from '@/app/components/base/button' type ResultPanelProps = { inputs?: string @@ -22,6 +28,8 @@ type ResultPanelProps = { showSteps?: boolean exceptionCounts?: number execution_metadata?: any + retry_events?: NodeTracing[] + onShowRetryDetail?: (retries: NodeTracing[]) => void } const ResultPanel: FC = ({ @@ -38,8 +46,11 @@ const ResultPanel: FC = ({ showSteps, exceptionCounts, execution_metadata, + retry_events, + onShowRetryDetail, }) => { const { t } = useTranslation() + return (
@@ -51,6 +62,23 @@ const ResultPanel: FC = ({ exceptionCounts={exceptionCounts} />
+ { + retry_events?.length && onShowRetryDetail && ( +
+ +
+ ) + }
void +} + +const RetryResultPanel: FC = ({ + list, + onBack, +}) => { + const { t } = useTranslation() + + return ( +
+
{ + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onBack() + }} + > + + {t('workflow.singleRun.back')} +
+ ({ + ...item, + title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`, + }))} + className='bg-background-section-burn' + /> +
+ ) +} +export default memo(RetryResultPanel) diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index 57b3a5cf5f145d..ad7897189592cf 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -21,6 +21,7 @@ import type { IterationDurationMap, NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void className?: string hideNodeInfo?: boolean hideNodeProcessDetail?: boolean @@ -160,6 +161,7 @@ function buildLogTree(nodes: NodeTracing[], t: (key: string) => string): Tracing const TracingPanel: FC = ({ list, onShowIterationDetail, + onShowRetryDetail, className, hideNodeInfo = false, hideNodeProcessDetail = false, @@ -251,7 +253,9 @@ const TracingPanel: FC = ({ diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index c40ea0de551b4b..6d0fabd90ef8c1 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -13,6 +13,7 @@ import type { DefaultValueForm, ErrorHandleTypeEnum, } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import type { WorkflowRetryConfig } from '@/app/components/workflow/nodes/_base/components/retry/types' export enum BlockEnum { Start = 'start', @@ -68,6 +69,7 @@ export type CommonNodeType = { _iterationIndex?: number _inParallelHovering?: boolean _waitingRun?: boolean + _retryIndex?: number isInIteration?: boolean iteration_id?: string selected?: boolean @@ -77,6 +79,7 @@ export type CommonNodeType = { width?: number height?: number error_strategy?: ErrorHandleTypeEnum + retry_config?: WorkflowRetryConfig default_value?: DefaultValueForm[] } & T & Partial> @@ -293,6 +296,7 @@ export enum NodeRunningStatus { Succeeded = 'succeeded', Failed = 'failed', Exception = 'exception', + Retry = 'retry', } export type OnNodeAdd = ( diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index abe129e6d60f72..4c61267e4cfc2a 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -26,6 +26,8 @@ import { } from './types' import { CUSTOM_NODE, + DEFAULT_RETRY_INTERVAL, + DEFAULT_RETRY_MAX, ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, @@ -279,6 +281,14 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated } + if (node.data.type === BlockEnum.HttpRequest && !node.data.retry_config) { + node.data.retry_config = { + retry_enabled: true, + max_retries: DEFAULT_RETRY_MAX, + retry_interval: DEFAULT_RETRY_INTERVAL, + } + } + return node }) } @@ -797,3 +807,7 @@ export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => { return false } + +export const hasRetryNode = (nodeType?: BlockEnum) => { + return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code +} diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index e2a2fdb59d8a91..fab25fa50958e2 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: 'There are {{num}} nodes in the process running abnormally, please go to tracing to check the 
logs.', }, }, + retry: { + retry: 'Retry', + retryOnFailure: 'retry on failure', + maxRetries: 'max retries', + retryInterval: 'retry interval', + retryTimes: 'Retry {{times}} times on failure', + retrying: 'Retrying...', + retrySuccessful: 'Retry successful', + retryFailed: 'Retry failed', + retryFailedTimes: '{{times}} retries failed', + times: 'times', + ms: 'ms', + retries: '{{num}} Retries', + }, }, start: { required: 'required', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 19cda330579dcd..dfad9208e73f4b 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: '流程中有 {{num}} 个节点运行异常,请前往追踪查看日志。', }, }, + retry: { + retry: '重试', + retryOnFailure: '失败时重试', + maxRetries: '最大重试次数', + retryInterval: '重试间隔', + retryTimes: '失败时重试 {{times}} 次', + retrying: '重试中...', + retrySuccessful: '重试成功', + retryFailed: '重试失败', + retryFailedTimes: '{{times}} 次重试失败', + times: '次', + ms: '毫秒', + retries: '{{num}} 重试次数', + }, }, start: { required: '必填', diff --git a/web/service/base.ts b/web/service/base.ts index 03421d92a4cb6f..22b1a43ad18359 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -62,6 +62,7 @@ export type IOnNodeStarted = (nodeStarted: NodeStartedResponse) => void export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void +export type IOnNodeRetry = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void @@ -92,6 +93,7 @@ export type IOtherOptions = { onIterationStart?: IOnIterationStarted onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onNodeRetry?: IOnNodeRetry onParallelBranchStarted?: IOnParallelBranchStarted onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk @@ -165,6 +167,7 @@ const handleStream = ( onIterationStart?: IOnIterationStarted, onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onNodeRetry?: IOnNodeRetry, onParallelBranchStarted?: IOnParallelBranchStarted, onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, @@ -256,6 +259,9 @@ const handleStream = ( else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'node_retry') { + onNodeRetry?.(bufferObj as NodeFinishedResponse) + } else if (bufferObj.event === 'parallel_branch_started') { onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) } @@ -462,6 +468,7 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, @@ -533,7 +540,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, 
onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' && !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 38f0bb5a409fd9..cd6e9cfa5f02ac 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -52,10 +52,12 @@ export type NodeTracing = { extras?: any expand?: boolean // for UI details?: NodeTracing[][] // iteration detail + retryDetail?: NodeTracing[] // retry detail parallel_id?: string parallel_start_node_id?: string parent_parallel_id?: string parent_parallel_start_node_id?: string + retry_index?: number } export type FetchWorkflowDraftResponse = { @@ -178,6 +180,7 @@ export type NodeFinishedResponse = { } created_at: number files?: FileResponse[] + retry_index?: number } } From 4211b9abbd4ba8cd0aa559280219cfbdfca7f990 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:12:01 +0800 Subject: [PATCH 27/91] chore: translate i18n files (#11892) Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com> --- web/i18n/de-DE/workflow.ts | 14 ++++++++++++++ web/i18n/es-ES/workflow.ts | 14 ++++++++++++++ web/i18n/fa-IR/workflow.ts | 14 ++++++++++++++ web/i18n/fr-FR/workflow.ts | 14 ++++++++++++++ web/i18n/hi-IN/workflow.ts | 14 ++++++++++++++ web/i18n/it-IT/workflow.ts | 14 ++++++++++++++ web/i18n/ja-JP/workflow.ts | 14 ++++++++++++++ web/i18n/ko-KR/workflow.ts | 14 ++++++++++++++ web/i18n/pl-PL/workflow.ts | 14 ++++++++++++++ web/i18n/pt-BR/workflow.ts | 14 ++++++++++++++ web/i18n/ro-RO/workflow.ts | 14 ++++++++++++++ web/i18n/ru-RU/workflow.ts | 14 ++++++++++++++ web/i18n/sl-SI/workflow.ts | 14 ++++++++++++++ web/i18n/th-TH/workflow.ts | 14 ++++++++++++++ web/i18n/tr-TR/workflow.ts | 14 ++++++++++++++ web/i18n/uk-UA/workflow.ts | 14 ++++++++++++++ web/i18n/vi-VN/workflow.ts | 14 ++++++++++++++ web/i18n/zh-Hant/workflow.ts | 14 ++++++++++++++ 18 files changed, 252 insertions(+) diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index 8888e237393f2f..38686f8c1d485d 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Fehlerbehandlung', tip: 'Ausnahmebehandlungsstrategie, die ausgelöst wird, wenn ein Knoten auf eine Ausnahme stößt.', }, + retry: { + retry: 'Wiederholen', + retryOnFailure: 'Wiederholen bei Fehler', + maxRetries: 'Max. 
Wiederholungen', + retryInterval: 'Wiederholungsintervall', + retryTimes: 'Wiederholen Sie {{times}} mal bei einem Fehler', + retrying: 'Wiederholung...', + retrySuccessful: 'Wiederholen erfolgreich', + retryFailed: 'Wiederholung fehlgeschlagen', + retryFailedTimes: '{{times}} fehlgeschlagene Wiederholungen', + times: 'mal', + ms: 'ms', + retries: '{{num}} Wiederholungen', + }, }, start: { required: 'erforderlich', diff --git a/web/i18n/es-ES/workflow.ts index c49c611da87c50..d112ad97b6f9a9 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Manejo de errores', tip: 'Estrategia de control de excepciones, que se desencadena cuando un nodo encuentra una excepción.', }, + retry: { + retryOnFailure: 'Volver a intentarlo en caso de error', + maxRetries: 'Número máximo de reintentos', + retryInterval: 'Intervalo de reintento', + retryTimes: 'Reintentar {{times}} veces en caso de error', + retrying: 'Reintentando...', + retrySuccessful: 'Volver a intentarlo correctamente', + retryFailed: 'Error en el reintento', + retryFailedTimes: '{{times}} reintentos fallidos', + times: 'veces', + ms: 'ms', + retries: '{{num}} Reintentos', + retry: 'Reintentar', + }, }, start: { required: 'requerido', diff --git a/web/i18n/fa-IR/workflow.ts index c29f9115565781..37cba2f16b97c7 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'مدیریت خطا', tip: 'استراتژی مدیریت استثنا، زمانی که یک گره با یک استثنا مواجه می شود، فعال می شود.', }, + retry: { + times: 'بار', + retryInterval: 'فاصله تلاش مجدد', + retryOnFailure: 'در مورد شکست دوباره امتحان کنید', + ms: 'ms', + retry: 'دوباره', + retries: '{{num}} تلاش های مجدد', + maxRetries: 'حداکثر تلاش مجدد', + retrying: 'تلاش مجدد...', + retryFailed: 'تلاش مجدد ناموفق بود', + retryTimes: '{{times}} بار در صورت شکست دوباره امتحان کنید', + retrySuccessful: 'امتحان مجدد با موفقیت انجام دهید', + retryFailedTimes: '{{times}} تلاش های مجدد ناموفق بود', + }, }, start: { required: 'الزامی', diff --git a/web/i18n/fr-FR/workflow.ts index a2b24061132579..e7d2802cb4f805 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestion des erreurs', tip: 'Stratégie de gestion des exceptions, déclenchée lorsqu’un nœud rencontre une exception.', }, + retry: { + retry: 'Réessayer', + retryOnFailure: 'Réessai en cas d’échec', + maxRetries: 'Nombre maximal de tentatives', + retryInterval: 'intervalle de nouvelle tentative', + retryTimes: 'Réessayez {{times}} fois en cas d’échec', + retrying: 'Réessayer...', + retrySuccessful: 'Réessai réussi', + retryFailed: 'Échec de la nouvelle tentative', + retryFailedTimes: '{{times}} les tentatives ont échoué', + times: 'fois', + ms: 'ms', + retries: '{{num}} Tentatives', + }, }, start: { required: 'requis', diff --git a/web/i18n/hi-IN/workflow.ts index 47589078ce3175..619abee12826eb 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -334,6 +334,20 @@ const translation = { title: 'त्रुटि हैंडलिंग', tip: 'अपवाद हैंडलिंग रणनीति, ट्रिगर जब एक नोड एक अपवाद का सामना करता है।', }, + retry: { + times: 'बार', + ms: 'ms', + retryInterval: 'अंतराल का पुनः प्रयास करें', + retrying: 'पुनर्प्रयास...', + retryFailed: 'पुनः प्रयास विफल रहा', + retryFailedTimes: '{{times}} पुनः प्रयास विफल रहे', + retryTimes: 'विफलता
पर {{times}} बार पुनः प्रयास करें', + retries: '{{num}} पुनर्प्रयास', + maxRetries: 'अधिकतम पुनः प्रयास करता है', + retrySuccessful: 'पुनः प्रयास सफल', + retry: 'पुनर्प्रयास', + retryOnFailure: 'विफलता पर पुनः प्रयास करें', + }, }, start: { required: 'आवश्यक', diff --git a/web/i18n/it-IT/workflow.ts index e760074e6a526d..f4390580d5efe7 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -337,6 +337,20 @@ const translation = { title: 'Gestione degli errori', tip: 'Strategia di gestione delle eccezioni, attivata quando un nodo rileva un\'eccezione.', }, + retry: { + retry: 'Ripetere', + retryOnFailure: 'Riprova in caso di errore', + maxRetries: 'Numero massimo di tentativi', + retryInterval: 'Intervallo tentativi', + retryTimes: 'Riprova {{times}} volte in caso di errore', + retrying: 'Riprovare...', + retryFailedTimes: '{{times}} tentativi falliti', + times: 'volte', + retries: '{{num}} Tentativi', + retrySuccessful: 'Riprova riuscito', + retryFailed: 'Nuovo tentativo non riuscito', + ms: 'ms', + }, }, start: { required: 'richiesto', diff --git a/web/i18n/ja-JP/workflow.ts index 8305105c22181d..1aa764a19f5a3d 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'エラー処理', tip: 'ノードが例外を検出したときにトリガーされる例外処理戦略。', }, + retry: { + retry: 'リトライ', + retryOnFailure: '失敗時の再試行', + maxRetries: '最大再試行回数', + retryInterval: '再試行間隔', + retrying: '再試行...', + retryFailed: '再試行に失敗しました', + times: '回', + ms: 'ミリ秒', + retryTimes: '失敗時に{{times}}回再試行', + retrySuccessful: '再試行に成功しました', + retries: '{{num}} 回の再試行', + retryFailedTimes: '{{times}}回のリトライが失敗しました', + }, }, start: { required: '必須', diff --git a/web/i18n/ko-KR/workflow.ts index cc2c1b1a282b11..4a4d2f91938004 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '오류 처리', tip: '노드에 예외가 발생할 때 트리거되는 예외 처리 전략입니다.', }, + retry: { + retry: '재시도', + retryOnFailure: '실패 시 재시도', + maxRetries: '최대 재시도 횟수', + retryInterval: '재시도 간격', + retryTimes: '실패 시 {{times}}번 재시도', + retrying: '재시도...', + retrySuccessful: '재시도 성공', + retryFailed: '재시도 실패', + retryFailedTimes: '{{times}} 재시도 실패', + times: '번', + ms: 'ms', + retries: '{{num}} 재시도', + }, }, start: { required: '필수', diff --git a/web/i18n/pl-PL/workflow.ts index 2db6cf2bfbeb83..13784df603fb7b 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Strategia obsługi wyjątków, wyzwalana, gdy węzeł napotka wyjątek.', title: 'Obsługa błędów', }, + retry: { + retry: 'Ponów próbę', + maxRetries: 'Maksymalna liczba ponownych prób', + retryInterval: 'Interwał ponawiania prób', + retryTimes: 'Ponów próbę {{times}} razy w przypadku niepowodzenia', + retrying: 'Ponawianie...', + retrySuccessful: 'Ponawianie próby powiodło się', + retryFailed: 'Ponawianie próby nie powiodło się', + times: 'razy', + retries: '{{num}} Ponownych prób', + retryOnFailure: 'Ponawianie próby w przypadku niepowodzenia', + retryFailedTimes: '{{times}} ponawianie prób nie powiodło się', + ms: 'Ms', + }, }, start: { required: 'wymagane', diff --git a/web/i18n/pt-BR/workflow.ts index 4d53ec07c77fb7..b99c64cdf4e003 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Tratamento de erros', tip: 'Estratégia de tratamento de
exceções, disparada quando um nó encontra uma exceção.', }, + retry: { + retry: 'Repetir', + retryOnFailure: 'Tentar novamente em caso de falha', + maxRetries: 'Máximo de tentativas', + retryInterval: 'Intervalo de repetição', + retryTimes: 'Tente novamente {{times}} vezes em caso de falha', + retrying: 'Repetindo...', + retrySuccessful: 'Repetição bem-sucedida', + retryFailed: 'Falha na nova tentativa', + retryFailedTimes: '{{times}} tentativas falharam', + times: 'vezes', + ms: 'ms', + retries: '{{num}} Tentativas', + }, }, start: { required: 'requerido', diff --git a/web/i18n/ro-RO/workflow.ts index 3dfa6d04edb43b..b142640c9b69a5 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestionarea erorilor', tip: 'Strategie de gestionare a excepțiilor, declanșată atunci când un nod întâlnește o excepție.', }, + retry: { + retry: 'Reîncercare', + retryOnFailure: 'Reîncercați în caz de eșec', + maxRetries: 'numărul maxim de încercări', + retryInterval: 'Interval de reîncercare', + retrying: 'Reîncerca...', + retrySuccessful: 'Reîncercați cu succes', + retryFailed: 'Reîncercarea a eșuat', + retryFailedTimes: '{{times}} reîncercări eșuate', + times: 'Ori', + ms: 'Ms', + retries: '{{num}} Încercări', + retryTimes: 'Reîncercați {{times}} ori în caz de eșec', + }, }, start: { required: 'necesar', diff --git a/web/i18n/ru-RU/workflow.ts index 600c59f2edc19b..49c43b4d6d85de 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обработка ошибок', tip: 'Стратегия обработки исключений, запускаемая при обнаружении исключения на узле.', }, + retry: { + retry: 'Снова пробовать', + retryOnFailure: 'Повторная попытка при неудаче', + maxRetries: 'максимальное количество повторных попыток', + retryInterval: 'Интервал повторных попыток', + retryTimes: 'Повторите {{times}} раз при неудаче', + retrying: 'Повтор...', + retrySuccessful: 'Повторить попытку успешно', + retryFailed: 'Повторная попытка не удалась', + times: 'раз', + ms: 'мс', + retryFailedTimes: 'Повторные попытки {{times}} не увенчались успехом', + retries: '{{num}} Повторных попыток', + }, }, start: { required: 'обязательно', diff --git a/web/i18n/sl-SI/workflow.ts index 2c9dab8b559ad8..7c40c25e92187f 100644 --- a/web/i18n/sl-SI/workflow.ts +++ b/web/i18n/sl-SI/workflow.ts @@ -759,6 +759,20 @@ const translation = { title: 'Ravnanje z napakami', tip: 'Strategija ravnanja z izjemami, ki se sproži, ko vozlišče naleti na izjemo.', }, + retry: { + retryOnFailure: 'Ponovni poskus ob neuspehu', + retryInterval: 'Interval ponovnega poskusa', + retrying: 'Ponovnim...', + retry: 'Ponoviti', + retryFailedTimes: '{{times}} ponovni poskusi niso uspeli', + retries: '{{num}} Poskusov', + times: 'Krat', + retryTimes: 'Ponovni poskus {{times}}-krat ob neuspehu', + retryFailed: 'Ponovni poskus ni uspel', + retrySuccessful: 'Ponovni poskus je bil uspešen', + maxRetries: 'Največ ponovnih poskusov', + ms: 'Ms', + }, }, start: { outputVars: { diff --git a/web/i18n/th-TH/workflow.ts index c4305466aac65d..b8d2e72de0b75b 100644 --- a/web/i18n/th-TH/workflow.ts +++ b/web/i18n/th-TH/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'การจัดการข้อผิดพลาด', tip: 'กลยุทธ์การจัดการข้อยกเว้น ทริกเกอร์เมื่อโหนดพบข้อยกเว้น', }, + retry: { + retry: 'ลอง', + retryOnFailure: 'ลองใหม่เมื่อล้มเหลว', + maxRetries:
'การลองซ้ําสูงสุด', + retryInterval: 'ช่วงเวลาลองใหม่', + retryTimes: 'ลอง {{times}} ครั้งเมื่อล้มเหลว', + retrying: 'กําลังลองซ้ํา...', + retrySuccessful: 'ลองใหม่สําเร็จ', + retryFailed: 'ลองใหม่ล้มเหลว', + retryFailedTimes: '{{times}} การลองซ้ําล้มเหลว', + times: 'ครั้ง', + retries: '{{num}} ลอง', + ms: 'ms', + }, }, start: { required: 'ต้องระบุ', diff --git a/web/i18n/tr-TR/workflow.ts index 951a20e0494221..edec6a0b49d9bd 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Hata İşleme', tip: 'Bir düğüm bir özel durumla karşılaştığında tetiklenen özel durum işleme stratejisi.', }, + retry: { + retry: 'Yeniden deneme', + retryOnFailure: 'Hata durumunda yeniden dene', + maxRetries: 'En fazla yeniden deneme', + times: 'kere', + retries: '{{num}} Yeniden deneme', + retryFailed: 'Yeniden deneme başarısız oldu', + retryInterval: 'Yeniden deneme aralığı', + retryTimes: 'Hata durumunda {{times}} kez yeniden deneyin', + retryFailedTimes: '{{times}} yeniden denemeleri başarısız oldu', + retrySuccessful: 'Yeniden deneme başarılı', + retrying: 'Yeniden deneniyor...', + ms: 'Ms', + }, }, start: { required: 'gerekli', diff --git a/web/i18n/uk-UA/workflow.ts index 2c00d3bf59a944..29fd9d8188aed4 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обробка помилок', tip: 'Стратегія обробки винятків, що спрацьовує, коли вузол стикається з винятком.', }, + retry: { + retry: 'Повторити', + retryOnFailure: 'повторити спробу в разі невдачі', + retryInterval: 'Інтервал повторних спроб', + retrying: 'Спроби...', + retryFailed: 'Повторна спроба не вдалася', + times: 'Разів', + ms: 'МС', + retries: '{{num}} Спроб', + maxRetries: 'Максимальна кількість повторних спроб', + retrySuccessful: 'Повторна спроба успішна', + retryFailedTimes: '{{times}} повторні спроби не вдалися', + retryTimes: 'Повторіть спробу {{times}} у разі невдачі', + }, }, start: { required: 'обов\'язковий', diff --git a/web/i18n/vi-VN/workflow.ts index 956fe84159b62d..9e16cb5347e417 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Chiến lược xử lý ngoại lệ, được kích hoạt khi một nút gặp phải ngoại lệ.', title: 'Xử lý lỗi', }, + retry: { + retry: 'Thử lại', + maxRetries: 'Số lần thử lại tối đa', + retryInterval: 'Khoảng thời gian thử lại', + retryTimes: 'Thử lại {{times}} lần khi không thành công', + retrying: 'Thử lại...', + retrySuccessful: 'Thử lại thành công', + retryFailed: 'Thử lại không thành công', + retryFailedTimes: '{{times}} lần thử lại không thành công', + retries: '{{num}} Thử lại', + retryOnFailure: 'Thử lại khi không thành công', + times: 'lần', + ms: 'Ms', + }, }, start: { required: 'bắt buộc', diff --git a/web/i18n/zh-Hant/workflow.ts index 4bbbf7a04f7beb..a78c6a2f04d4ea 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '錯誤處理', tip: '異常處理策略,當節點遇到異常時觸發。', }, + retry: { + retry: '重試', + retryOnFailure: '失敗時重試', + maxRetries: '最大重試次數', + retryInterval: '重試間隔', + retryTimes: '失敗時重試 {{times}} 次', + retrying: '重試...', + retrySuccessful: '重試成功', + retryFailed: '重試失敗', + retryFailedTimes: '{{times}} 次重試失敗', + times: '次', + ms: '毫秒', + retries: '{{num}} 重試', + }, }, start: { required: '必填', From
From ef7e47d1625c5e678b8bf5da2d253e2fedd9470f Mon Sep 17 00:00:00 2001
From: zxhlyh
Date: Fri, 20 Dec 2024 16:12:34 +0800
Subject: [PATCH 28/91] fix: rerank switch (#11897)

---
 .../components/workflow/nodes/knowledge-retrieval/utils.ts    | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts
index e9da9accccf783..794fcbca4aa216 100644
--- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts
+++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts
@@ -129,9 +129,6 @@ export const getMultipleRetrievalConfig = (
     reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
   }

-  if (!rerankModelIsValid)
-    result.reranking_model = undefined
-
   const setDefaultWeights = () => {
     result.weights = {
       vector_setting: {
@@ -198,7 +195,6 @@ export const getMultipleRetrievalConfig = (
       setDefaultWeights()
     }
   }
-
   if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
     result.reranking_mode = RerankingModeEnum.WeightedScore
     setDefaultWeights()
From ac635c70cd0a7de8cb959ed89be9cadf7420329f Mon Sep 17 00:00:00 2001
From: yihong
Date: Fri, 20 Dec 2024 17:19:46 +0800
Subject: [PATCH 29/91] fix: doc can not extract tables (#11879)

Signed-off-by: yihong0618
Co-authored-by: akinobu-i
---
 api/core/workflow/nodes/document_extractor/node.py | 38 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 95d0ea3aab54d6..6d82dbe6d70da3 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -1,6 +1,7 @@
 import csv
 import io
 import json
+import logging
 import os
 import tempfile

@@ -22,6 +23,8 @@
 from .entities import DocumentExtractorNodeData
 from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

+logger = logging.getLogger(__name__)
+

 class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
     """
@@ -177,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:


 def _extract_text_from_doc(file_content: bytes) -> str:
+    """
+    Extract text from a DOC/DOCX file.
+    For now, only paragraphs and tables are supported; add more element types if needed.
+    """
     try:
         doc_file = io.BytesIO(file_content)
         doc = docx.Document(doc_file)
-        return "\n".join([paragraph.text for paragraph in doc.paragraphs])
+        text = []
+        # Process paragraphs
+        for paragraph in doc.paragraphs:
+            if paragraph.text.strip():
+                text.append(paragraph.text)
+
+        # Process tables
+        for table in doc.tables:
+            # Render each table as markdown (header row, separator, data rows)
+            try:
+                # A table can raise errors while parsing, so skip any table that fails.
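+                # Illustrative output for a 2x2 table (values are made up):
+                #   | Name | Qty |
+                #   | --- | --- |
+                #   | foo | 2 |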
+                    if len(table.rows) > 0 and table.rows[0].cells is not None:
+                        # Check if any cell in the table has text
+                        has_content = False
+                        for row in table.rows:
+                            if any(cell.text.strip() for cell in row.cells):
+                                has_content = True
+                                break
+
+                        if has_content:
+                            markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
+                            markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
+                            for row in table.rows[1:]:
+                                markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
+                            text.append(markdown_table)
+            except Exception as e:
+                logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
+                continue
+
+        return "\n".join(text)
     except Exception as e:
         raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
From 2681bafb76dc3488cfc3d6305b9ada6a63597f9c Mon Sep 17 00:00:00 2001
From: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com>
Date: Fri, 20 Dec 2024 19:23:42 +0900
Subject: [PATCH 30/91] fix: handle document fetching from URL in Anthropic LLM model, solving base64 decoding error (#11858)

---
 api/core/model_runtime/model_providers/anthropic/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index e1d35ff872a82f..c0ea8c6325c845 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -531,7 +531,7 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
                             "source": {
                                 "type": "base64",
                                 "media_type": message_content.mime_type,
-                                "data": message_content.data,
+                                "data": message_content.base64_data,
                             },
                         }
                         sub_messages.append(sub_message_dict)
From f53741c5b9a084a6117d0e3f6f6be615fa4c2741 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Fri, 20 Dec 2024 18:24:47 +0800
Subject: [PATCH 31/91] revert: these 2 settings (#11906)

Signed-off-by: -LAN-
Co-authored-by: -LAN-
---
 docker/.env.example        | 3 ---
 docker/docker-compose.yaml | 2 --
 2 files changed, 5 deletions(-)

diff --git a/docker/.env.example b/docker/.env.example
index e8ec246ae20b4f..43e67a8db41254 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -923,6 +923,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
 # Maximum number of submitted thread count in a ThreadPool for parallel node execution
 MAX_SUBMIT_COUNT=100

-# Proxy
-HTTP_PROXY=
-HTTPS_PROXY=
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index d0738d93054fcb..99bc14c7171e4b 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -386,8 +386,6 @@ x-shared-env: &shared-api-worker-env
   CSP_WHITELIST: ${CSP_WHITELIST:-}
   CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
   MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100}
-  HTTP_PROXY: ${HTTP_PROXY:-}
-  HTTPS_PROXY: ${HTTPS_PROXY:-}

 services:
   # API service
From 6ded06c6d96678c8499d0d20e51de68964bbfff6 Mon Sep 17 00:00:00 2001
From: zhu-an <70234959+xhdd123321@users.noreply.github.com>
Date: Sat, 21 Dec 2024 11:26:17 +0800
Subject: [PATCH 32/91] fix: dataset search-input composition can't work in chrome (#11907)

Co-authored-by: zhaoqingyu.1075
---
 web/app/components/base/search-input/index.tsx | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx
index
89345fbe3282d3..556a7bdf49296a 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -23,6 +23,7 @@ const SearchInput: FC = ({ const { t } = useTranslation() const [focus, setFocus] = useState(false) const isComposing = useRef(false) + const [internalValue, setInternalValue] = useState(value) return (
= ({ white && '!bg-white hover:!bg-white group-hover:!bg-white placeholder:!text-gray-400', )} placeholder={placeholder || t('common.operation.search')!} - value={value} + value={internalValue} onChange={(e) => { + setInternalValue(e.target.value) if (!isComposing.current) onChange(e.target.value) }} onCompositionStart={() => { isComposing.current = true }} - onCompositionEnd={() => { + onCompositionEnd={(e) => { isComposing.current = false + onChange(e.data) }} onFocus={() => setFocus(true)} onBlur={() => setFocus(false)} @@ -63,7 +66,10 @@ const SearchInput: FC = ({ {value && (
onChange('')} + onClick={() => { + onChange('') + setInternalValue('') + }} >
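The fix above is the standard controlled-input workaround for IME composition in Chrome: mirror the value in local state while the user is composing, and only propagate the committed text to the parent once composition ends. A minimal standalone sketch of the pattern, using hypothetical names rather than the component's real props:

    // IME-safe controlled input state (assumes React; not the actual component code)
    import { useRef, useState } from 'react'

    export function useImeSafeValue(onCommit: (v: string) => void) {
      const [draft, setDraft] = useState('')
      const composing = useRef(false)
      return {
        value: draft,
        handleChange(next: string) {
          setDraft(next)          // always reflect keystrokes, even mid-composition
          if (!composing.current)
            onCommit(next)        // commit immediately outside composition
        },
        handleCompositionStart() { composing.current = true },
        handleCompositionEnd(data: string) {
          composing.current = false
          onCommit(data)          // commit the finished composition exactly once
        },
      }
    }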
From 7a00798027324b23acd658862129c509d0b89ffc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B7=A6=E5=87=8C?= <68994887+Wuhu-dsm@users.noreply.github.com>
Date: Sat, 21 Dec 2024 11:54:12 +0800
Subject: [PATCH 33/91] Update input-var-list.tsx (#9987)

---
 .../workflow/nodes/tool/components/input-var-list.tsx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx
index db1a32e31928d3..bab7c20d5b785b 100644
--- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx
+++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx
@@ -162,7 +162,7 @@ const InputVarList: FC = ({
             readonly={readOnly}
             isShowNodeName
             nodeId={nodeId}
-            value={varInput?.type === VarKindType.constant ? (varInput?.value || '') : (varInput?.value || [])}
+            value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])}
             onChange={handleNotMixedTypeChange(variable)}
             onOpen={handleOpen(index)}
             defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)}
From de8800f41ac17829562172d4b9a10ecbc2d9b9d3 Mon Sep 17 00:00:00 2001
From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com>
Date: Sat, 21 Dec 2024 13:36:13 +0800
Subject: [PATCH 34/91] add three aws tools (#11905)

Co-authored-by: Yuanbo Li
---
 .../builtin/aws/tools/bedrock_retrieve.py     | 115 ++++++
 .../builtin/aws/tools/bedrock_retrieve.yaml   |  87 ++++
 .../provider/builtin/aws/tools/nova_canvas.py | 357 +++++++++++++++++
 .../builtin/aws/tools/nova_canvas.yaml        | 175 +++++++++
 .../provider/builtin/aws/tools/nova_reel.py   | 371 ++++++++++++++++++
 .../provider/builtin/aws/tools/nova_reel.yaml | 124 ++++++
 .../provider/builtin/aws/tools/s3_operator.py |  80 ++++
 .../builtin/aws/tools/s3_operator.yaml        |  98 +++++
 8 files changed, 1407 insertions(+)
 create mode 100644 api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
 create mode 100644 api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
 create mode 100644 api/core/tools/provider/builtin/aws/tools/nova_canvas.py
 create mode 100644 api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml
 create mode 100644 api/core/tools/provider/builtin/aws/tools/nova_reel.py
 create mode 100644 api/core/tools/provider/builtin/aws/tools/nova_reel.yaml
 create mode 100644 api/core/tools/provider/builtin/aws/tools/s3_operator.py
 create mode 100644 api/core/tools/provider/builtin/aws/tools/s3_operator.yaml

diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
new file mode 100644
index 00000000000000..050b468b740c27
--- /dev/null
+++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
@@ -0,0 +1,115 @@
+import json
+import operator
+from typing import Any, Optional, Union
+
+import boto3
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class BedrockRetrieveTool(BuiltinTool):
+    bedrock_client: Any = None
+    knowledge_base_id: Optional[str] = None
+    topk: Optional[int] = None
+
+    def _bedrock_retrieve(
+        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
+    ):
+        try:
+            retrieval_query = {"text": query_input}
+
+            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
+
+            # If a metadata filter was provided, add it to the retrieval configuration
+            if metadata_filter:
+                retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
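+                # Filter shape follows the Bedrock Agent Runtime retrieve API, e.g.
+                # (illustrative values): {"greaterThan": {"key": "year", "value": 2020}}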
+
+            response = self.bedrock_client.retrieve(
+                knowledgeBaseId=knowledge_base_id,
+                retrievalQuery=retrieval_query,
+                retrievalConfiguration=retrieval_configuration,
+            )
+
+            results = []
+            for result in response.get("retrievalResults", []):
+                results.append(
+                    {
+                        "content": result.get("content", {}).get("text", ""),
+                        "score": result.get("score", 0.0),
+                        "metadata": result.get("metadata", {}),
+                    }
+                )
+
+            return results
+        except Exception as e:
+            raise Exception(f"Error retrieving from knowledge base: {str(e)}")
+
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        line = 0
+        try:
+            if not self.bedrock_client:
+                aws_region = tool_parameters.get("aws_region")
+                if aws_region:
+                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
+                else:
+                    self.bedrock_client = boto3.client("bedrock-agent-runtime")
+
+            line = 1
+            if not self.knowledge_base_id:
+                self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
+                if not self.knowledge_base_id:
+                    return self.create_text_message("Please provide knowledge_base_id")
+
+            line = 2
+            if not self.topk:
+                self.topk = tool_parameters.get("topk", 5)
+
+            line = 3
+            query = tool_parameters.get("query", "")
+            if not query:
+                return self.create_text_message("Please input query")
+
+            # Get the metadata filter, if one was provided
+            metadata_filter_str = tool_parameters.get("metadata_filter")
+            metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
+
+            line = 4
+            retrieved_docs = self._bedrock_retrieve(
+                query_input=query,
+                knowledge_base_id=self.knowledge_base_id,
+                num_results=self.topk,
+                metadata_filter=metadata_filter,  # pass the metadata filter through to the retrieve call
+            )
+
+            line = 5
+            # Sort results by score in descending order
+            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
+
+            line = 6
+            return [self.create_json_message(res) for res in sorted_docs]
+
+        except Exception as e:
+            return self.create_text_message(f"Exception {str(e)}, line : {line}")
+
+    def validate_parameters(self, parameters: dict[str, Any]) -> None:
+        """
+        Validate the parameters
+        """
+        if not parameters.get("knowledge_base_id"):
+            raise ValueError("knowledge_base_id is required")
+
+        if not parameters.get("query"):
+            raise ValueError("query is required")
+
+        # Optional: if provided, validate that metadata_filter is a valid JSON object
+        metadata_filter_str = parameters.get("metadata_filter")
+        if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
+            raise ValueError("metadata_filter must be a valid JSON object")
diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
new file mode 100644
index 00000000000000..9e51d52def4037
--- /dev/null
+++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
@@ -0,0 +1,87 @@
+identity:
+  name: bedrock_retrieve
+  author: AWS
+  label:
+    en_US: Bedrock Retrieve
+    zh_Hans: Bedrock检索
+    pt_BR: Bedrock Retrieve
+  icon: icon.svg
+
+description:
+  human:
+    en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
+    zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
+    pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
+  llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
+
+parameters:
+  - name: knowledge_base_id
+    type: string
+    required: true
+    label:
+      en_US: Bedrock Knowledge Base ID
+      zh_Hans: Bedrock知识库ID
+      pt_BR: Bedrock Knowledge Base ID
+    human_description:
+      en_US: ID of the Bedrock Knowledge Base to retrieve from
+      zh_Hans: 用于检索的Bedrock知识库ID
+      pt_BR: ID of the Bedrock Knowledge Base to retrieve from
+    llm_description: ID of the Bedrock Knowledge Base to retrieve from
+    form: form
+
+  - name: query
+    type: string
+    required: true
+    label:
+      en_US: Query string
+      zh_Hans: 查询语句
+      pt_BR: Query string
+    human_description:
+      en_US: The search query to retrieve relevant information
+      zh_Hans: 用于检索相关信息的查询语句
+      pt_BR: The search query to retrieve relevant information
+    llm_description: The search query to retrieve relevant information
+    form: llm
+
+  - name: topk
+    type: number
+    required: false
+    form: form
+    label:
+      en_US: Limit for results count
+      zh_Hans: 返回结果数量限制
+      pt_BR: Limit for results count
+    human_description:
+      en_US: Maximum number of results to return
+      zh_Hans: 最大返回结果数量
+      pt_BR: Maximum number of results to return
+    min: 1
+    max: 10
+    default: 5
+
+  - name: aws_region
+    type: string
+    required: false
+    label:
+      en_US: AWS Region
+      zh_Hans: AWS 区域
+      pt_BR: AWS Region
+    human_description:
+      en_US: AWS region where the Bedrock Knowledge Base is located
+      zh_Hans: Bedrock知识库所在的AWS区域
+      pt_BR: AWS region where the Bedrock Knowledge Base is located
+    llm_description: AWS region where the Bedrock Knowledge Base is located
+    form: form
+
+  - name: metadata_filter
+    type: string
+    required: false
+    label:
+      en_US: Metadata Filter
+      zh_Hans: 元数据过滤器
+      pt_BR: Metadata Filter
+    human_description:
+      en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
+      zh_Hans: '元数据的JSON格式过滤条件(例如,{"greaterThan": {"key": "aaa", "value": 10}})'
+      pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
+    form: form
diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.py b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py
new file mode 100644
index 00000000000000..954dbe35a4a784
--- /dev/null
+++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py
@@ -0,0 +1,357 @@
+import base64
+import json
+import logging
+import re
+from datetime import datetime
+from typing import Any, Union
+from urllib.parse import urlparse
+
+import boto3
+
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.tool.builtin_tool import BuiltinTool
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class NovaCanvasTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        Invoke AWS Bedrock Nova Canvas model for image generation
+        """
+        # Get common parameters
+        prompt = tool_parameters.get("prompt", "")
+        image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
+        if not prompt:
+            return self.create_text_message("Please provide a text prompt for image generation.")
+        if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
+            return self.create_text_message("Please provide a valid S3 URI for image output.")
+
+        task_type =
tool_parameters.get("task_type", "TEXT_IMAGE") + aws_region = tool_parameters.get("aws_region", "us-east-1") + + # Get common image generation config parameters + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + cfg_scale = tool_parameters.get("cfg_scale", 8.0) + negative_prompt = tool_parameters.get("negative_prompt", "") + seed = tool_parameters.get("seed", 0) + quality = tool_parameters.get("quality", "standard") + + # Handle S3 image if provided + image_input_s3uri = tool_parameters.get("image_input_s3uri", "") + if task_type != "TEXT_IMAGE": + if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3": + return self.create_text_message("Please provide a valid S3 URI for image to image generation.") + + # Parse S3 URI + parsed_uri = urlparse(image_input_s3uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + # Initialize S3 client and download image + s3_client = boto3.client("s3") + response = s3_client.get_object(Bucket=bucket, Key=key) + image_data = response["Body"].read() + + # Base64 encode the image + input_image = base64.b64encode(image_data).decode("utf-8") + + try: + # Initialize Bedrock client + bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region) + + # Base image generation config + image_generation_config = { + "width": width, + "height": height, + "cfgScale": cfg_scale, + "seed": seed, + "numberOfImages": 1, + "quality": quality, + } + + # Prepare request body based on task type + body = {"imageGenerationConfig": image_generation_config} + + if task_type == "TEXT_IMAGE": + body["taskType"] = "TEXT_IMAGE" + body["textToImageParams"] = {"text": prompt} + if negative_prompt: + body["textToImageParams"]["negativeText"] = negative_prompt + + elif task_type == "COLOR_GUIDED_GENERATION": + colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680") + if not self._validate_color_string(colors): + return self.create_text_message("Please provide valid colors in hexadecimal format.") + + body["taskType"] = "COLOR_GUIDED_GENERATION" + body["colorGuidedGenerationParams"] = { + "colors": colors.split("-"), + "referenceImage": input_image, + "text": prompt, + } + if negative_prompt: + body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt + + elif task_type == "IMAGE_VARIATION": + similarity_strength = tool_parameters.get("similarity_strength", 0.5) + + body["taskType"] = "IMAGE_VARIATION" + body["imageVariationParams"] = { + "images": [input_image], + "similarityStrength": similarity_strength, + "text": prompt, + } + if negative_prompt: + body["imageVariationParams"]["negativeText"] = negative_prompt + + elif task_type == "INPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image inpainting.") + + body["taskType"] = "INPAINTING" + body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt} + if negative_prompt: + body["inPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "OUTPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image outpainting.") + outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT") + + body["taskType"] = "OUTPAINTING" + body["outPaintingParams"] = { + "image": input_image, + "maskPrompt": mask_prompt, + "outPaintingMode": outpainting_mode, + "text": prompt, + } + if 
negative_prompt: + body["outPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "BACKGROUND_REMOVAL": + body["taskType"] = "BACKGROUND_REMOVAL" + body["backgroundRemovalParams"] = {"image": input_image} + + else: + return self.create_text_message(f"Unsupported task type: {task_type}") + + # Call Nova Canvas model + response = bedrock.invoke_model( + body=json.dumps(body), + modelId="amazon.nova-canvas-v1:0", + accept="application/json", + contentType="application/json", + ) + + # Process response + response_body = json.loads(response.get("body").read()) + if response_body.get("error"): + raise Exception(f"Error in model response: {response_body.get('error')}") + base64_image = response_body.get("images")[0] + + # Upload to S3 if image_output_s3uri is provided + try: + # Parse S3 URI for output + parsed_uri = urlparse(image_output_s3uri) + output_bucket = parsed_uri.netloc + output_base_path = parsed_uri.path.lstrip("/") + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_key = f"{output_base_path}/canvas-output-{timestamp}.png" + + # Initialize S3 client if not already done + s3_client = boto3.client("s3", region_name=aws_region) + + # Decode base64 image and upload to S3 + image_data = base64.b64decode(base64_image) + s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png") + logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}") + except Exception as e: + logger.exception("Failed to upload image to S3") + # Return image + return [ + self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"), + self.create_blob_message( + blob=base64.b64decode(base64_image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ), + ] + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def _validate_color_string(self, color_string) -> bool: + color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$" + + if re.match(color_pattern, color_string): + return True + return False + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the image you want to generate or modify", + zh_Hans="您想要生成或修改的图像的文本描述", + ), + llm_description="Describe the image you want to generate or how you want to modify the input image", + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"), + ), + ToolParameter( + name="image_output_s3uri", + label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI" + ), + ), + ToolParameter( + name="width", + label=I18nObject(en_US="Width", zh_Hans="宽度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + 
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"), + ), + ToolParameter( + name="height", + label=I18nObject(en_US="Height", zh_Hans="高度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"), + ), + ToolParameter( + name="cfg_scale", + label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=8.0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词" + ), + ), + ToolParameter( + name="negative_prompt", + label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容" + ), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="us-east-1", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="task_type", + label=I18nObject(en_US="Task Type", zh_Hans="任务类型"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="TEXT_IMAGE", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"), + ), + ToolParameter( + name="quality", + label=I18nObject(en_US="Quality", zh_Hans="质量"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="standard", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)" + ), + ), + ToolParameter( + name="colors", + label=I18nObject(en_US="Colors", zh_Hans="颜色"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680", + zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680", + ), + ), + ToolParameter( + name="similarity_strength", + label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0.5, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How similar the generated image should be to the input image (0.0 to 1.0)", + zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)", + ), + ), + ToolParameter( + name="mask_prompt", + label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description to generate mask for inpainting/outpainting", + zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述", + 
), + ), + ToolParameter( + name="outpainting_mode", + label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="DEFAULT", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Mode for outpainting (DEFAULT or other supported modes)", + zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml new file mode 100644 index 00000000000000..a72fd9c8efcce1 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml @@ -0,0 +1,175 @@ +identity: + name: nova_canvas + author: AWS + label: + en_US: AWS Bedrock Nova Canvas + zh_Hans: AWS Bedrock Nova Canvas + icon: icon.svg +description: + human: + en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。 + llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. +parameters: + - name: task_type + type: string + required: false + default: TEXT_IMAGE + label: + en_US: Task Type + zh_Hans: 任务类型 + human_description: + en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL) + zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除) + form: llm + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the image you want to generate or modify + zh_Hans: 您想要生成或修改的图像的文本描述 + llm_description: Describe the image you want to generate or how you want to modify the input image + form: llm + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input image s3 uri + zh_Hans: 输入图片的s3 uri + human_description: + en_US: The input image to modify (required for all modes except TEXT_IMAGE) + zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要) + llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE. + form: llm + - name: image_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png + zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传 + llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename. 
+ form: form + - name: negative_prompt + type: string + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + human_description: + en_US: Things you don't want in the generated image + zh_Hans: 您不想在生成的图像中出现的内容 + form: llm + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: 宽度 + human_description: + en_US: Width of the generated image + zh_Hans: 生成图像的宽度 + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: 高度 + human_description: + en_US: Height of the generated image + zh_Hans: 生成图像的高度 + form: form + default: 1024 + - name: cfg_scale + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG比例 + human_description: + en_US: How strongly the image should conform to the prompt + zh_Hans: 图像应该多大程度上符合提示词 + form: form + default: 8.0 + - name: seed + type: number + required: false + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for image generation + zh_Hans: 图像生成的随机种子 + form: form + default: 0 + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + - name: quality + type: string + required: false + default: standard + label: + en_US: Quality + zh_Hans: 质量 + human_description: + en_US: Quality of the generated image (standard or premium) + zh_Hans: 生成图像的质量(标准或高级) + form: form + - name: colors + type: string + required: false + label: + en_US: Colors + zh_Hans: 颜色 + human_description: + en_US: List of colors for color-guided generation + zh_Hans: 颜色引导生成的颜色列表 + form: form + - name: similarity_strength + type: number + required: false + default: 0.5 + label: + en_US: Similarity Strength + zh_Hans: 相似度强度 + human_description: + en_US: How similar the generated image should be to the input image (0.0 to 1.0) + zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0) + form: form + - name: mask_prompt + type: string + required: false + label: + en_US: Mask Prompt + zh_Hans: 蒙版提示词 + human_description: + en_US: Text description to generate mask for inpainting/outpainting + zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述 + form: llm + - name: outpainting_mode + type: string + required: false + default: DEFAULT + label: + en_US: Outpainting Mode + zh_Hans: 外补绘制模式 + human_description: + en_US: Mode for outpainting (DEFAULT or other supported modes) + zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式) + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.py b/api/core/tools/provider/builtin/aws/tools/nova_reel.py new file mode 100644 index 00000000000000..bfd3d302b22d48 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.py @@ -0,0 +1,371 @@ +import base64 +import logging +import time +from io import BytesIO +from typing import Any, Optional, Union +from urllib.parse import urlparse + +import boto3 +from botocore.exceptions import ClientError +from PIL import Image + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +NOVA_REEL_DEFAULT_REGION = "us-east-1" +NOVA_REEL_DEFAULT_DIMENSION = "1280x720" +NOVA_REEL_DEFAULT_FPS = 24 +NOVA_REEL_DEFAULT_DURATION = 6 +NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0" +NOVA_REEL_STATUS_CHECK_INTERVAL = 5 + +# Image requirements 
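+# (These mirror the validation in _process_and_validate_image below: the optional
+#  conditioning image must be 1280x720, 8-bit RGB, with no transparent pixels.)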
+NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280 +NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720 +NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB" + + +class NovaReelTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Reel model for video generation. + + Args: + user_id: The ID of the user making the request + tool_parameters: Dictionary containing the tool parameters + + Returns: + ToolInvokeMessage containing either the video content or status information + """ + try: + # Validate and extract parameters + params = self._validate_and_extract_parameters(tool_parameters) + if isinstance(params, ToolInvokeMessage): + return params + + # Initialize AWS clients + bedrock, s3_client = self._initialize_aws_clients(params["aws_region"]) + + # Prepare model input + model_input = self._prepare_model_input(params, s3_client) + if isinstance(model_input, ToolInvokeMessage): + return model_input + + # Start video generation + invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"]) + invocation_arn = invocation["invocationArn"] + + # Handle async/sync mode + return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"]) + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + error_message = e.response.get("Error", {}).get("Message", str(e)) + logger.exception(f"AWS API error: {error_code} - {error_message}") + return self.create_text_message(f"AWS service error: {error_code} - {error_message}") + except Exception as e: + logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to generate video: {str(e)}") + + def _validate_and_extract_parameters( + self, tool_parameters: dict[str, Any] + ) -> Union[dict[str, Any], ToolInvokeMessage]: + """Validate and extract parameters from the input dictionary.""" + prompt = tool_parameters.get("prompt", "") + video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip() + + # Validate required parameters + if not prompt: + return self.create_text_message("Please provide a text prompt for video generation.") + if not video_output_s3uri: + return self.create_text_message("Please provide an S3 URI for video output.") + + # Validate S3 URI format + if not video_output_s3uri.startswith("s3://"): + return self.create_text_message("Invalid S3 URI format. 
Must start with 's3://'") + + # Ensure S3 URI ends with '/' + video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/" + + return { + "prompt": prompt, + "video_output_s3uri": video_output_s3uri, + "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(), + "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION), + "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION), + "seed": int(tool_parameters.get("seed", 0)), + "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)), + "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)), + "async_mode": bool(tool_parameters.get("async", True)), + } + + def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]: + """Initialize AWS Bedrock and S3 clients.""" + bedrock = boto3.client(service_name="bedrock-runtime", region_name=region) + s3_client = boto3.client("s3", region_name=region) + return bedrock, s3_client + + def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]: + """Prepare the input for the Nova Reel model.""" + model_input = { + "taskType": "TEXT_VIDEO", + "textToVideoParams": {"text": params["prompt"]}, + "videoGenerationConfig": { + "durationSeconds": params["duration"], + "fps": params["fps"], + "dimension": params["dimension"], + "seed": params["seed"], + }, + } + + # Add image if provided + if params["image_input_s3uri"]: + try: + image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"]) + if not image_data: + return self.create_text_message("Failed to retrieve image from S3") + + # Process and validate image + processed_image = self._process_and_validate_image(image_data) + if isinstance(processed_image, ToolInvokeMessage): + return processed_image + + # Convert processed image to base64 + img_buffer = BytesIO() + processed_image.save(img_buffer, format="PNG") + img_buffer.seek(0) + input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + model_input["textToVideoParams"]["images"] = [ + {"format": "png", "source": {"bytes": input_image_base64}} + ] + except Exception as e: + logger.error(f"Error processing input image: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to process input image: {str(e)}") + + return model_input + + def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]: + """ + Process and validate the input image according to Nova Reel requirements. + + Requirements: + - Must be 1280x720 pixels + - Must be RGB format (8 bits per channel) + - If PNG, alpha channel must not have transparent/translucent pixels + """ + try: + # Open image + img = Image.open(BytesIO(image_data)) + + # Convert RGBA to RGB if needed, ensuring no transparency + if img.mode == "RGBA": + # Check for transparency + if img.getchannel("A").getextrema()[0] < 255: + return self.create_text_message( + "PNG image contains transparent or translucent pixels, which is not supported. " + "Please provide an image without transparency." + ) + # Convert to RGB + img = img.convert("RGB") + elif img.mode != "RGB": + # Convert any other mode to RGB + img = img.convert("RGB") + + # Validate/adjust dimensions + if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT): + logger.warning( + f"Image dimensions {img.size} do not match required dimensions " + f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..." 
+ ) + img = img.resize( + (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS + ) + + # Validate bit depth + if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE: + return self.create_text_message( + f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel" + ) + + return img + + except Exception as e: + logger.error(f"Error processing image: {str(e)}", exc_info=True) + return self.create_text_message( + "Failed to process image. Please ensure the image is a valid JPEG or PNG file." + ) + + def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]: + """Download and return image data from S3.""" + parsed_uri = urlparse(s3_uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + response = s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]: + """Start the async video generation process.""" + return bedrock.start_async_invoke( + modelId=NOVA_REEL_MODEL_ID, + modelInput=model_input, + outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}}, + ) + + def _handle_generation_mode( + self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool + ) -> ToolInvokeMessage: + """Handle async or sync video generation mode.""" + invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + video_uri = f"{video_path}/output.mp4" + + if async_mode: + return self.create_text_message( + f"Video generation started.\nInvocation ARN: {invocation_arn}\n" + f"Video will be available at: {video_uri}" + ) + + return self._wait_for_completion(bedrock, s3_client, invocation_arn) + + def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage: + """Wait for video generation completion and handle the result.""" + while True: + status_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + status = status_response["status"] + video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + + if status == "Completed": + return self._handle_completed_video(s3_client, video_path) + elif status == "Failed": + failure_message = status_response.get("failureMessage", "Unknown error") + return self.create_text_message(f"Video generation failed.\nError: {failure_message}") + elif status == "InProgress": + time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL) + else: + return self.create_text_message(f"Unexpected status: {status}") + + def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage: + """Handle completed video generation and return the result.""" + parsed_uri = urlparse(video_path) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + "/output.mp4" + + try: + response = s3_client.get_object(Bucket=bucket, Key=key) + video_content = response["Body"].read() + return [ + self.create_text_message(f"Video is available at: {video_path}/output.mp4"), + self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"), + ] + except Exception as e: + logger.error(f"Error downloading video: {str(e)}", exc_info=True) + return self.create_text_message( + f"Video generation completed but failed to download video: {str(e)}\n" + f"Video is available at: s3://{bucket}/{key}" + ) + + def get_runtime_parameters(self) -> list[ToolParameter]: + 
"""Define the tool's runtime parameters.""" + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述" + ), + llm_description="Describe the video you want to generate", + ), + ToolParameter( + name="video_output_s3uri", + label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI" + ), + ), + ToolParameter( + name="dimension", + label=I18nObject(en_US="Dimension", zh_Hans="尺寸"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_DIMENSION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"), + ), + ToolParameter( + name="duration", + label=I18nObject(en_US="Duration", zh_Hans="时长"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_DURATION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"), + ), + ToolParameter( + name="fps", + label=I18nObject(en_US="FPS", zh_Hans="帧率"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_FPS, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数" + ), + ), + ToolParameter( + name="async", + label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"), + type=ToolParameter.ToolParameterType.BOOLEAN, + required=False, + default=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)", + zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)", + ), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_REGION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame", + zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml new file mode 100644 index 00000000000000..16df5ba5c9d1e3 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml @@ -0,0 +1,124 @@ +identity: + 
name: nova_reel + author: AWS + label: + en_US: AWS Bedrock Nova Reel + zh_Hans: AWS Bedrock Nova Reel + icon: icon.svg +description: + human: + en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution. + +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the video you want to generate + zh_Hans: 您想要生成的视频的文本描述 + llm_description: Describe the video you want to generate + form: llm + + - name: video_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: S3 URI where the generated video will be stored + zh_Hans: 生成的视频将存储的S3 URI + form: form + + - name: dimension + type: string + required: false + default: 1280x720 + label: + en_US: Dimension + zh_Hans: 尺寸 + human_description: + en_US: Video dimensions (width x height) + zh_Hans: 视频尺寸(宽 x 高) + form: form + + - name: duration + type: number + required: false + default: 6 + label: + en_US: Duration + zh_Hans: 时长 + human_description: + en_US: Video duration in seconds + zh_Hans: 视频时长(秒) + form: form + + - name: seed + type: number + required: false + default: 0 + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for video generation + zh_Hans: 视频生成的随机种子 + form: form + + - name: fps + type: number + required: false + default: 24 + label: + en_US: FPS + zh_Hans: 帧率 + human_description: + en_US: Frames per second for the generated video + zh_Hans: 生成视频的每秒帧数 + form: form + + - name: async + type: boolean + required: false + default: true + label: + en_US: Async Mode + zh_Hans: 异步模式 + human_description: + en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion) + zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成) + form: llm + + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input Image S3 URI + zh_Hans: 输入图像S3 URI + human_description: + en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame + zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI + form: llm + +development: + dependencies: + - boto3 + - pillow diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.py b/api/core/tools/provider/builtin/aws/tools/s3_operator.py new file mode 100644 index 00000000000000..e4026b07a87310 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.py @@ -0,0 +1,80 @@ +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class S3Operator(BuiltinTool): + s3_client: Any = None + + def _invoke( + 
self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + # Initialize S3 client if not already done + if not self.s3_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.s3_client = boto3.client("s3") + + # Parse S3 URI + s3_uri = tool_parameters.get("s3_uri") + if not s3_uri: + return self.create_text_message("s3_uri parameter is required") + + parsed_uri = urlparse(s3_uri) + if parsed_uri.scheme != "s3": + return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") + + bucket = parsed_uri.netloc + # Remove leading slash from key + key = parsed_uri.path.lstrip("/") + + operation_type = tool_parameters.get("operation_type", "read") + generate_presign_url = tool_parameters.get("generate_presign_url", False) + presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour + + if operation_type == "write": + text_content = tool_parameters.get("text_content") + if not text_content: + return self.create_text_message("text_content parameter is required for write operation") + + # Write content to S3 + self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8")) + result = f"s3://{bucket}/{key}" + + # Generate presigned URL for the written object if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + else: # read operation + # Get object from S3 + response = self.s3_client.get_object(Bucket=bucket, Key=key) + result = response["Body"].read().decode("utf-8") + + # Generate presigned URL if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + return self.create_text_message(text=result) + + except self.s3_client.exceptions.NoSuchBucket: + return self.create_text_message(f"Bucket '{bucket}' does not exist") + except self.s3_client.exceptions.NoSuchKey: + return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'") + except Exception as e: + return self.create_text_message(f"Exception: {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml new file mode 100644 index 00000000000000..642fc2966e9b6d --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml @@ -0,0 +1,98 @@ +identity: + name: s3_operator + author: AWS + label: + en_US: AWS S3 Operator + zh_Hans: AWS S3 读写器 + pt_BR: AWS S3 Operator + icon: icon.svg +description: + human: + en_US: AWS S3 Writer and Reader + zh_Hans: 读写S3 bucket中的文件 + pt_BR: AWS S3 Writer and Reader + llm: AWS S3 Writer and Reader +parameters: + - name: text_content + type: string + required: false + label: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + human_description: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + llm_description: The text to write + form: llm + - name: s3_uri + type: string + required: true + label: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + human_description: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + llm_description: s3 uri + form: llm + - name: aws_region + type: string + required: true + label: + en_US: region of bucket + zh_Hans: bucket 所在的region + 
pt_BR: region of bucket + human_description: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + llm_description: region of bucket + form: form + - name: operation_type + type: select + required: true + label: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + human_description: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + default: read + options: + - value: read + label: + en_US: read + zh_Hans: 读 + - value: write + label: + en_US: write + zh_Hans: 写 + form: form + - name: generate_presign_url + type: boolean + required: false + label: + en_US: Generate presigned URL + zh_Hans: 生成预签名URL + human_description: + en_US: Whether to generate a presigned URL for the S3 object + zh_Hans: 是否生成S3对象的预签名URL + default: false + form: form + - name: presign_expiry + type: number + required: false + label: + en_US: Presigned URL expiration time + zh_Hans: 预签名URL有效期 + human_description: + en_US: Expiration time in seconds for the presigned URL + zh_Hans: 预签名URL的有效期(秒) + default: 3600 + form: form From 786cb6859b0fbd056104d1c67cf47d8fcf517b0f Mon Sep 17 00:00:00 2001 From: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Date: Sat, 21 Dec 2024 16:05:04 +0900 Subject: [PATCH 35/91] =?UTF-8?q?fix:=20ensure=20WorkflowRun=20attributes?= =?UTF-8?q?=20are=20refreshed=20in=20WorkflowCycleMana=E2=80=A6=20(#11913)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/app/task_pipeline/workflow_cycle_manage.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index e2fa12b1cddd70..0e6425fa14e8cf 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -505,6 +505,12 @@ def _workflow_finish_to_stream_response( :param workflow_run: workflow run :return: """ + # Attach WorkflowRun to an active session so "created_by_role" can be accessed. 
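+        # (SQLAlchemy note: merge() returns a copy of the object attached to the
+        # current session, and the refresh() below re-SELECTs the row so expired
+        # attributes can load without a DetachedInstanceError.)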
+ workflow_run = db.session.merge(workflow_run) + + # Refresh to ensure any expired attributes are fully loaded + db.session.refresh(workflow_run) + created_by = None if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: created_by_account = workflow_run.created_by_account From 8c559d6231506ba677ce3652cf574f484d356da2 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:19:46 +0800 Subject: [PATCH 36/91] fix(retrieval_service): avoid to use exception (#11925) Signed-off-by: -LAN- --- api/core/rag/datasource/retrieval_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b2141396d6dcc4..18f8d4e8392302 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -103,7 +103,7 @@ def retrieve( if exceptions: exception_message = ";\n".join(exceptions) - raise Exception(exception_message) + raise ValueError(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( From 8f73670925fc1bf4b45252a84f63f0a10ba54023 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:21:00 +0800 Subject: [PATCH 37/91] fix(file_factory): validate upload_file_id before querying UploadFile (#11937) Signed-off-by: -LAN- --- api/factories/file_factory.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 01d95dcfb3d2fe..13034f5cf5688b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -116,8 +116,11 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") stmt = select(UploadFile).where( - UploadFile.id == mapping.get("upload_file_id"), + UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) From 606aadb891b607f359530ad33542093fc18488ec Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:21:09 +0800 Subject: [PATCH 38/91] refactor: update builtin tool provider methods to use session management (#11938) Signed-off-by: -LAN- --- .../console/workspace/tool_providers.py | 24 ++++++++------ .../tools/builtin_tools_manage_service.py | 33 +++++++++---------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2cd6dcda3b5ea7..9e62a546997b12 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,12 +3,14 @@ from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -91,12 +93,16 @@ def post(self, provider): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - 
tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user_id = current_user.id tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e2e49d017ef167..fada881fdeb741 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,9 @@ import logging from pathlib import Path +from sqlalchemy import select +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,7 +35,7 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str tenant_id=tenant_id, provider_controller=provider_controller ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = ( + builtin_provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -71,19 +74,18 @@ def list_builtin_provider_credentials_schema(provider_name): return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): """ update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, ) + provider = session.scalar(stmt) try: # get provider @@ -115,13 +117,10 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) - db.session.commit() + session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() # delete cache tool_configuration.delete_tool_credentials_cache() @@ -129,15 +128,15 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ get builtin tool provider credentials """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ) .first() ) @@ -156,7 
+155,7 @@ def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st """ delete tool provider """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, From 810adb8a94ecb0e63ddb6416eaeb786b7c7f248e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:21:30 +0800 Subject: [PATCH 39/91] fix: change OutputParserError to inherit from ValueError (#11935) Signed-off-by: -LAN- --- api/core/llm_generator/output_parser/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 1e743f1757473e..0922806ca88ce6 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserError(Exception): +class OutputParserError(ValueError): pass From c07d9e96ce6d3c0e0d135f000150f9d6b03d319d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:21:57 +0800 Subject: [PATCH 40/91] fix(nodes): handle errors in question_classifier and parameter_extractor (#11927) Signed-off-by: -LAN- --- .../parameter_extractor/parameter_extractor_node.py | 9 +++++++++ .../question_classifier/question_classifier_node.py | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 5b960ea6151954..c8c854a43b3269 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -179,6 +179,15 @@ def _run(self): error=str(e), metadata={}, ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, + error=str(e), + metadata={}, + ) error = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 15cfabe4780ee2..5043e25e2b8820 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -154,8 +154,7 @@ def _run(self): }, llm_usage=usage, ) - - except ValueError as e: + except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, From 90323cd35595c229d715e37600035a5204813751 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:22:06 +0800 Subject: [PATCH 41/91] fix(tool_file_manager): raise ValueError when get timeout (#11928) Signed-off-by: -LAN- --- api/core/tools/tool_file_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 5052f0897a2958..2aaca6d82e36b1 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -8,9 +8,10 @@ from typing import Optional, Union from uuid import uuid4 -from httpx import get +import httpx from configs import dify_config +from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -94,12 +95,11 @@ def create_file_by_url( ) -> ToolFile: # try to 
download image try: - response = get(file_url) + response = ssrf_proxy.get(file_url) response.raise_for_status() blob = response.content - except Exception as e: - logger.exception(f"Failed to download file from {file_url}") - raise + except httpx.TimeoutException as e: + raise ValueError(f"timeout when downloading file from {file_url}") mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" From 455791b710c72585b41d9dfadac0deb04c751940 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:22:14 +0800 Subject: [PATCH 42/91] fix(model_runtime): make invoke as ValueError (#11929) Signed-off-by: -LAN- --- api/core/model_runtime/errors/invoke.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index edfb19c7d07d4c..76754253611568 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,7 +1,7 @@ from typing import Optional -class InvokeError(Exception): +class InvokeError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None From b8d42cdea74306b79002c1d6a35e0ec76d173cc9 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:22:47 +0800 Subject: [PATCH 43/91] fix: change MaxRetriesExceededError to inherit from ValueError (#11934) Signed-off-by: -LAN- --- api/core/helper/ssrf_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index d8aa80536412e6..ce15a9667d8d18 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,7 +24,7 @@ STATUS_FORCELIST = [429, 500, 502, 503, 504] -class MaxRetriesExceededError(Exception): +class MaxRetriesExceededError(ValueError): """Raised when the maximum number of retries is exceeded.""" pass From 0b06235527806e5651f01272bd5e16b214e9234b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:22:57 +0800 Subject: [PATCH 44/91] =?UTF-8?q?fix:=20add=20RemoteFileUploadError=20for?= =?UTF-8?q?=20better=20error=20handling=20in=20remote=20fi=E2=80=A6=20(#11?= =?UTF-8?q?933)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- --- api/controllers/common/errors.py | 5 +++++ api/controllers/console/remote_files.py | 13 +++++++++---- api/controllers/web/remote_files.py | 13 +++++++++---- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index c71f1ce5a31027..9f762b3135e2a4 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -4,3 +4,8 @@ class FilenameNotExistsError(HTTPException): code = 400 description = "The specified filename does not exist." + + +class RemoteFileUploadError(HTTPException): + code = 400 + description = "Error uploading remote file." 
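Subclassing werkzeug's HTTPException, as RemoteFileUploadError does above, is all Flask needs to turn a raise into a proper HTTP response: the class-level code becomes the status and description the default body. A small sketch of that behavior, assuming a bare Flask app (illustrative only):

    from flask import Flask
    from werkzeug.exceptions import HTTPException

    class RemoteFileUploadError(HTTPException):
        code = 400
        description = "Error uploading remote file."

    app = Flask(__name__)

    @app.get("/fetch")
    def fetch():
        # Flask converts this raise into a 400 response; the message passed
        # here overrides the class-level description.
        raise RemoteFileUploadError("Failed to fetch file from https://example.com")

    # app.test_client().get("/fetch").status_code  ->  400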
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fac1341b39b596..b8cf019e4f068d 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -7,6 +7,7 @@ import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields @@ -43,10 +44,14 @@ def post(self): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index d6b8eb2855725c..ae68df6bdc4e48 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -5,6 +5,7 @@ import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -38,10 +39,14 @@ def post(self, app_model, end_user): # Add app_model and end_user parameters url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) From 5e37ab60d8c6f9aded0a72e0d4180bdac9c3795b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:23:03 +0800 Subject: [PATCH 45/91] fix: validate response type in transform_response method (#11931) Signed-off-by: -LAN- --- api/core/helper/code_executor/template_transformer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b7a07b21e1d784..cf422fd023e01e 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -33,13 +33,16 @@ def extract_result_str_from_response(cls, response: str) -> str: return result @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str): """ Transform response to dict :param response: response :return: """ - return json.loads(cls.extract_result_str_from_response(response)) + result = json.loads(cls.extract_result_str_from_response(response)) + if not isinstance(result, dict): + raise ValueError("Result must be a dict") + return result @classmethod @abstractmethod From 
599d410d9999785c70f3f2353700c9057a334611 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:23:12 +0800 Subject: [PATCH 46/91] fix: validate reranking model attributes before processing (#11930) Signed-off-by: -LAN- --- api/core/rag/data_post_processor/data_post_processor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 992415657eced2..d17d76333ee705 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -83,11 +83,15 @@ def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[d if reranking_model: try: model_manager = ModelManager() + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") + if not reranking_provider_name or not reranking_model_name: + return None rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], + provider=reranking_provider_name, model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], + model=reranking_model_name, ) return rerank_model_instance except InvokeAuthorizationError: From a227af3664e3108a3c2956d7d363b93c7d08bc9a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 21 Dec 2024 21:24:22 +0800 Subject: [PATCH 47/91] =?UTF-8?q?fix(code=5Fnode):=20update=20type=20hints?= =?UTF-8?q?=20for=20string=20and=20number=20checks=20in=20Cod=E2=80=A6=20(?= =?UTF-8?q?#11936)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- --- api/core/workflow/nodes/code/code_node.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 19b9078a5ce4a4..6adf82c455147f 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Optional from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -67,18 +67,17 @@ def _run(self) -> NodeRunResult: return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _check_string(self, value: str, variable: str) -> str: + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, str): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: raise OutputValidationError( @@ -88,18 +87,17 @@ def _check_string(self, value: str, variable: str) -> str: return value.replace("\x00", "") - def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: + def _check_number(self, value: int | float | None, variable: str) -> int | float | None: """ Check number :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, int | float): - 
if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: raise OutputValidationError( From e22cc2811462de33757e688c3e93ec042e897909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Sat, 21 Dec 2024 21:24:33 +0800 Subject: [PATCH 48/91] fix:log error(#11942) (#11943) --- api/models/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/models/model.py b/api/models/model.py index 8608b12af1fc8e..d3c69cfbe1278e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -700,8 +700,10 @@ def admin_feedback_stats(self): def status_count(self): messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() status_counts = { + WorkflowRunStatus.RUNNING: 0, WorkflowRunStatus.SUCCEEDED: 0, WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, WorkflowRunStatus.PARTIAL_SUCCESSED: 0, } From 9ee9e9c6de52aa01a8840bfc7f7a2feed6af6312 Mon Sep 17 00:00:00 2001 From: yihong Date: Sat, 21 Dec 2024 21:24:59 +0800 Subject: [PATCH 49/91] fix: self.method should method in api_tool.py (#11926) Signed-off-by: yihong0618 --- api/core/tools/tool/api_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 636debffd46f05..48aac75dbb4115 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -210,7 +210,7 @@ def do_http_request( ) return response else: - raise ValueError(f"Invalid http method {self.method}") + raise ValueError(f"Invalid http method {method}") def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 From 9578246bbb2029251880fd4745e26360aa14813c Mon Sep 17 00:00:00 2001 From: jiangbo721 <365065261@qq.com> Date: Sat, 21 Dec 2024 23:13:58 +0800 Subject: [PATCH 50/91] fix: The default updated_at when a workflow is created (#11709) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 刘江波 --- api/models/account.py | 21 ++++----- api/models/api_based_extension.py | 4 +- api/models/dataset.py | 32 +++++++------- api/models/model.py | 72 +++++++++++++++---------------- api/models/provider.py | 30 +++++++------ api/models/source.py | 9 ++-- api/models/tools.py | 26 +++++------ api/models/web.py | 6 ++- api/models/workflow.py | 16 +++---- 9 files changed, 110 insertions(+), 106 deletions(-) diff --git a/api/models/account.py b/api/models/account.py index ce17b90def3fa8..932ba1da578b8a 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,6 +2,7 @@ import json from flask_login import UserMixin +from sqlalchemy import func from .engine import db from .types import StringUUID @@ -30,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, 
nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def is_password_set(self): @@ -187,8 +188,8 @@ class Tenant(db.Model): plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -228,8 +229,8 @@ class TenantAccountJoin(db.Model): current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AccountIntegrate(db.Model): @@ -245,8 +246,8 @@ class AccountIntegrate(db.Model): provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class InvitationCode(db.Model): @@ -265,4 +266,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 4d4182cabdc163..fbffe7a3b2ee9d 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,5 +1,7 @@ import enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -23,4 +25,4 @@ class APIBasedExtension(db.Model): name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff 
--git a/api/models/dataset.py b/api/models/dataset.py index 97e4d6c0ef8981..7279e8d5b3394a 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -50,9 +50,9 @@ class Dataset(db.Model): indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model): mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -264,7 +264,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -303,7 +303,7 @@ class Document(db.Model): archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) @@ -527,9 +527,9 @@ class DocumentSegment(db.Model): disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -697,7 +697,7 @@ class 
Embedding(db.Model): ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): @@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(db.Model): @@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(db.Model): @@ -751,7 +751,7 @@ class Whitelist(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(db.Model): @@ -768,7 +768,7 @@ class DatasetPermission(db.Model): account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(db.Model): @@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model): tenant_id = db.Column(StringUUID, nullable=False) settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model): dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, 
nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index d3c69cfbe1278e..04bb0a947a8ec3 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -30,7 +30,7 @@ class DifySetup(db.Model): __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -85,9 +85,9 @@ class App(db.Model): tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -226,9 +226,9 @@ class AppModelConfig(db.Model): model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -482,8 +482,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -507,7 +507,7 @@ class InstalledApp(db.Model): position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -548,8 +548,8 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = 
db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( @@ -791,8 +791,8 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1117,8 +1117,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1164,9 +1164,7 @@ def __init__( upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(db.Model): @@ -1186,8 +1184,8 @@ class MessageAnnotation(db.Model): content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1216,7 +1214,7 @@ class AppAnnotationHitHistory(db.Model): source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) @@ -1250,9 +1248,9 @@ class AppAnnotationSetting(db.Model): score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_account(self): @@ -1298,9 +1296,9 @@ class OperationLog(db.Model): account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(UserMixin, db.Model): @@ -1319,8 +1317,8 @@ class EndUser(UserMixin, db.Model): name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Site(db.Model): @@ -1351,9 +1349,9 @@ class Site(db.Model): prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) code = db.Column(db.String(255)) @property @@ -1395,7 +1393,7 @@ class ApiToken(db.Model): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1426,9 +1424,7 @@ class UploadFile(db.Model): 
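One mechanical change recurs throughout this patch: the raw SQL default db.text("CURRENT_TIMESTAMP(0)") becomes SQLAlchemy's func.current_timestamp(). Both produce a server-side column default, but the func form is dialect-aware rather than PostgreSQL-specific. A self-contained sketch of the new style on a throwaway SQLite engine (illustrative, not part of the diff):

    from sqlalchemy import Column, DateTime, Integer, create_engine, func
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Example(Base):
        __tablename__ = "example"
        id = Column(Integer, primary_key=True)
        # Emitted as "DEFAULT CURRENT_TIMESTAMP" in the generated DDL.
        created_at = Column(DateTime, nullable=False, server_default=func.current_timestamp())

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)  # the default is applied by the database, not by Python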
db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) @@ -1485,7 +1481,7 @@ class ApiRequest(db.Model): request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(db.Model): @@ -1657,7 +1653,7 @@ class Tag(db.Model): type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(db.Model): @@ -1673,7 +1669,7 @@ class TagBinding(db.Model): tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(db.Model): @@ -1687,8 +1683,10 @@ class TraceAppConfig(db.Model): app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property diff --git a/api/models/provider.py b/api/models/provider.py index 65f70b76e9c362..fdd3e802d79211 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,7 @@ from enum import Enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -60,8 +62,8 @@ class Provider(db.Model): quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -108,8 +110,8 @@ class ProviderModel(db.Model): model_type = db.Column(db.String(40), 
nullable=False) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(db.Model): @@ -124,8 +126,8 @@ class TenantDefaultModel(db.Model): provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(db.Model): @@ -139,8 +141,8 @@ class TenantPreferredModelProvider(db.Model): tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(db.Model): @@ -164,8 +166,8 @@ class ProviderOrder(db.Model): paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(db.Model): @@ -186,8 +188,8 @@ class ProviderModelSetting(db.Model): model_type = db.Column(db.String(40), nullable=False) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(db.Model): @@ -209,5 +211,5 @@ class LoadBalancingModelConfig(db.Model): name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 4d98572ef81fcb..114db8e1100e5d 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json +from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from .engine import db @@ -19,8 +20,8 @@ class DataSourceOauthBinding(db.Model): access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) @@ -37,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model): category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): diff --git a/api/models/tools.py b/api/models/tools.py index c390be4625aab2..e90ab669c66f1e 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -2,7 +2,7 @@ from typing import Optional import sqlalchemy as sa -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject @@ -36,8 +36,8 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def credentials(self) -> dict: @@ -74,8 +74,8 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def description_i18n(self) -> I18nObject: @@ -120,8 +120,8 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -198,8 +198,8 @@ class WorkflowToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def user(self) -> Account | None: @@ -251,8 +251,8 @@ class ToolModelInvoke(db.Model): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ToolConversationVariables(db.Model): @@ -278,8 +278,8 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> dict: diff --git a/api/models/web.py b/api/models/web.py index a0f87cf4566f04..f5d54630e17361 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,3 +1,5 @@ +from sqlalchemy import func + from .engine import db from .model import Message from .types import StringUUID @@ -15,7 +17,7 @@ class SavedMessage(db.Model): message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -34,4 +36,4 @@ class PinnedConversation(db.Model): conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) 
created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index e933382a841a6f..0e5cf38d25e864 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from datetime import UTC, datetime +from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional, Union @@ -103,12 +103,10 @@ class Workflow(db.Model): graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() + db.DateTime, nullable=False, server_default=func.current_timestamp(), server_onupdate=func.current_timestamp() ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -406,7 +404,7 @@ class WorkflowRun(db.Model): total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @@ -636,7 +634,7 @@ class WorkflowNodeExecution(db.Model): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -755,7 +753,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -781,7 +779,7 @@ class ConversationVariable(db.Model): conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, index=True, 
server_default=func.current_timestamp()) updated_at = db.Column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) From 366857cd26e563d66c3802524f690e80b4a7fc81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Sat, 21 Dec 2024 23:14:05 +0800 Subject: [PATCH 51/91] fix: gemini system prompt with variable raise error (#11946) --- .../model_providers/google/llm/llm.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b54668a12d3207..7d19ccbb74a011 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -21,6 +21,7 @@ PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -143,7 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: - ping_message = SystemPromptMessage(content="ping") + ping_message = UserPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_output_tokens": 5}) except Exception as ex: @@ -187,17 +188,23 @@ def _generate( config_kwargs["stop_sequences"] = stop genai.configure(api_key=credentials["google_api_key"]) - google_model = genai.GenerativeModel(model_name=model) history = [] + system_instruction = None for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) + elif content["role"] == "system": + system_instruction = content["parts"][0] else: history.append(content) + if not history: + raise InvokeError("The user prompt message is required. 
You only add a system prompt message.") + + google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, generation_config=genai.types.GenerationConfig(**config_kwargs), @@ -404,7 +411,10 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: ) return glm_content elif isinstance(message, SystemPromptMessage): - return {"role": "user", "parts": [to_part(message.content)]} + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) + return {"role": "system", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", From 3d07a94bd74cc3eddf4d2e530e575c422d5a1675 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:39:29 +0800 Subject: [PATCH 52/91] =?UTF-8?q?fix:=20refactor=20conversation=20paginati?= =?UTF-8?q?on=20to=20use=20SQLAlchemy=20session=20manag=E2=80=A6=20(#11956?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- --- .../console/explore/conversation.py | 20 +++++--- .../service_api/app/conversation.py | 20 +++++--- api/controllers/web/conversation.py | 22 ++++---- api/models/web.py | 3 +- api/services/conversation_service.py | 50 ++++++++++--------- api/services/web_conversation_service.py | 18 ++++--- 6 files changed, 78 insertions(+), 55 deletions(-) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6f9d7769b942ce..5e7a3da017edd7 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,12 +1,14 @@ from flask_login import current_user from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -34,14 +36,16 @@ def get(self, installed_app): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=current_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.EXPLORE, + pinned=pinned, + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c62fd77d367aa6..32940cbc29f355 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range +from 
sqlalchemy.orm import Session from werkzeug.exceptions import NotFound import services @@ -7,6 +8,7 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import ( conversation_delete_fields, conversation_infinite_scroll_pagination_fields, @@ -39,14 +41,16 @@ def get(self, app_model: App, end_user: EndUser): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c3b0cd4f44b2ac..fe0d7c74f32cff 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,11 +1,13 @@ from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -40,15 +42,17 @@ def get(self, app_model, end_user): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, + sort_by=args["sort_by"], + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/models/web.py b/api/models/web.py index f5d54630e17361..028a768519d99a 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,5 @@ from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column from .engine import db from .model import Message @@ -33,7 +34,7 @@ class PinnedConversation(db.Model): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git 
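The conversation-service diff that follows is the heart of this refactor: legacy `db.session.query(...)` chains become 2.0-style `select()` statements executed on a caller-supplied `Session`, and paging stays keyset-based, filtering relative to the last-seen row instead of using OFFSET. A reduced sketch of that idea, assuming plain SQLAlchemy; `Item` stands in for `Conversation` and every name is illustrative:

```python
# Reduced sketch of keyset pagination in the 2.0 style used below;
# Item stands in for Conversation and all names are illustrative.
from sqlalchemy import Column, DateTime, Integer, create_engine, desc, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    __tablename__ = "item"

    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime, nullable=False, server_default=func.current_timestamp())


def page_after(session: Session, last_id: int | None, limit: int) -> list[Item]:
    stmt = select(Item)
    if last_id is not None:
        last = session.scalar(stmt.where(Item.id == last_id))
        if last is None:
            raise ValueError("last item does not exist")
        # Keyset condition: only rows strictly after the cursor row in the
        # sort order, so concurrent inserts never shift rows between pages.
        stmt = stmt.where(Item.updated_at < last.updated_at)
    stmt = stmt.order_by(desc(Item.updated_at)).limit(limit)
    return list(session.scalars(stmt).all())


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(Item() for _ in range(5))
    session.commit()
    print([item.id for item in page_after(session, last_id=None, limit=2)])
```

The `has_more` computation in the real service has the same shape: apply the cursor condition once more past the last row of the current page and count whatever remains.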
a/api/services/conversation_service.py b/api/services/conversation_service.py index 8642972710fd1f..456dc3ebebaa28 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import UTC, datetime from typing import Optional, Union -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,19 +19,21 @@ class ConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Conversation).filter( + stmt = select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -38,37 +41,40 @@ def pagination_by_last_id( Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - if include_ids is not None: - base_query = base_query.filter(Conversation.id.in_(include_ids)) - + stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: - base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) if last_id: - last_conversation = base_query.filter(Conversation.id == last_id).first() + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() # build filters based on sorting - filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) - base_query = base_query.filter(filter_condition) - - base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) - - conversations = base_query.limit(limit).all() + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] rest_filter_condition = cls._build_filter_condition( - sort_field, sort_direction, current_page_last_conversation, is_next_page=True + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, ) - rest_count = base_query.filter(rest_filter_condition).count() - + count_stmt = stmt.where(rest_filter_condition) + count_stmt = select(func.count()).select_from(count_stmt.subquery()) + rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True @@ -81,11 
+87,9 @@ def _get_sort_params(cls, sort_by: str): return sort_by, asc @classmethod - def _build_filter_condition( - cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False - ): + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + if sort_direction == desc: return getattr(Conversation, sort_field) < field_value else: return getattr(Conversation, sort_field) > field_value diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index d7ccc964cb70f8..508fe20970a703 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,5 +1,8 @@ from typing import Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -13,6 +16,8 @@ class WebConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], @@ -23,24 +28,25 @@ def pagination_by_last_id( ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None - if pinned is not None: - pinned_conversations = ( - db.session.query(PinnedConversation) - .filter( + if pinned is not None and user: + stmt = ( + select(PinnedConversation.conversation_id) + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) .order_by(PinnedConversation.created_at.desc()) - .all() ) - pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] + pinned_conversation_ids = session.scalars(stmt).all() + if pinned: include_ids = pinned_conversation_ids else: exclude_ids = pinned_conversation_ids return ConversationService.pagination_by_last_id( + session=session, app_model=app_model, user=user, last_id=last_id, From 2ad2a402fb1c582b6318b463b7e926cfed2aff77 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:40:12 +0800 Subject: [PATCH 53/91] fix(app_dsl_service): handle missing app mode with a ValueError (#11945) Signed-off-by: -LAN- --- api/services/app_dsl_service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8180c3b400719a..0478903fa4d2d7 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -340,7 +340,10 @@ def _create_or_update_app( ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) - app_mode = AppMode(app_data["mode"]) + app_mode = app_data.get("mode") + if not app_mode: + raise ValueError("missing app mode") + app_mode = AppMode(app_mode) # Set icon type icon_type_value = icon_type or app_data.get("icon_type") From a056a9d6010a2dfb5bb9ced1f6d50104016b2176 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:40:43 +0800 Subject: [PATCH 54/91] feat(code_node): add more check (#11949) Signed-off-by: -LAN- --- api/core/helper/code_executor/code_executor.py | 2 +- .../helper/code_executor/template_transformer.py | 13 +++++++++----
api/core/workflow/nodes/code/code_node.py | 16 +++++++--------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ead462..584e3e9698a88d 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -118,7 +118,7 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: return response.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf422fd023e01e..605719747a7b56 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,7 +25,7 @@ def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, st return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") @@ -33,15 +33,20 @@ def extract_result_str_from_response(cls, response: str) -> str: return result @classmethod - def transform_response(cls, response: str): + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - result = json.loads(cls.extract_result_str_from_response(response)) + try: + result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") if not isinstance(result, dict): - raise ValueError("Result must be a dict") + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") return result @classmethod diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 6adf82c455147f..4e371ca43645a5 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -59,7 +59,7 @@ def _run(self) -> NodeRunResult: ) # Transform result - result = self._transform_result(result, self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -116,14 +116,12 @@ def _check_number(self, value: int | float | None, variable: str) -> int | float return value def _transform_result( - self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 - ) -> dict: - """ - Transform result - :param result: result - :param output_schema: output schema - :return: - """ + self, + result: Mapping[str, Any], + output_schema: Optional[dict[str, CodeNodeData.Output]], + prefix: str = "", + depth: int = 1, + ): if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") From 
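The checks added to `transform_response` above form a small validation pipeline: extract the tagged payload from the runner's stdout, parse it as JSON, and reject anything that is not a dict keyed by strings. A standalone sketch of that pipeline; the `RESULT_TAG` value and the sample payload are illustrative, not the real tag:

```python
# Standalone sketch of the response validation added in this patch;
# the RESULT_TAG value and the sample payload are illustrative.
import json
import re
from collections.abc import Mapping
from typing import Any

RESULT_TAG = "<<RESULT>>"


def transform_response(response: str) -> Mapping[str, Any]:
    match = re.search(rf"{re.escape(RESULT_TAG)}(.*){re.escape(RESULT_TAG)}", response, re.DOTALL)
    if not match:
        raise ValueError("failed to parse result")
    try:
        result = json.loads(match.group(1))
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    if not all(isinstance(k, str) for k in result):
        raise ValueError("result keys must be strings")
    return result


print(transform_response('<<RESULT>>{"answer": 42}<<RESULT>>'))  # {'answer': 42}
```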
90f093eb671ae8aef693f2e41bfb2780d8369181 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:40:56 +0800 Subject: [PATCH 55/91] fix(json_in_md_parser): improve error messages for JSON parsing failures (#11948) Signed-off-by: -LAN- --- api/libs/json_in_md_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41c5d20c4b08b9..267af611f5e8cb 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict: extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) else: - raise Exception("Could not find JSON block in the output.") + raise ValueError("could not find json block in the output.") return parsed @@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserError(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"got invalid json object. error: {e}") for key in expected_keys: if key not in json_obj: raise OutputParserError( - f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" ) return json_obj From dd0e81d0940c492df5ea5b4cc17bf259b8a33cca Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:41:06 +0800 Subject: [PATCH 56/91] =?UTF-8?q?fix:=20enhance=20type=20hints=20and=20imp?= =?UTF-8?q?rove=20audio=20message=20handling=20in=20TTS=20pub=E2=80=A6=20(?= =?UTF-8?q?#11947)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- --- .../app_generator_tts_publisher.py | 32 +++++++++++-------- .../advanced_chat/generate_task_pipeline.py | 10 +++--- api/core/app/apps/base_app_queue_manager.py | 7 ++-- .../apps/workflow/generate_task_pipeline.py | 10 +++--- .../easy_ui_based_generate_task_pipeline.py | 6 ++-- 5 files changed, 36 insertions(+), 29 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 18b115dfe40d3c..29709914b7cfb8 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -4,14 +4,17 @@ import queue import re import threading +from collections.abc import Iterable from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentMessageEvent, QueueLLMChunkEvent, QueueNodeSucceededEvent, QueueTextChunkEvent, + WorkflowQueueMessage, ) -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -21,7 +24,7 @@ def __init__(self, status: str, audio): self.status = status -def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): ) -def _process_future(future_queue, audio_queue): +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: 
queue.Queue[AudioTrunk], +): while True: try: future = future_queue.get() if future is None: break - for audio in future.result(): + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: audio_base64 = base64.b64encode(bytes(audio)) audio_queue.put(AudioTrunk("responding", audio=audio_base64)) except Exception as e: @@ -49,8 +58,8 @@ def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" - self._audio_queue = queue.Queue() - self._msg_queue = queue.Queue() + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( @@ -66,14 +75,11 @@ def __init__(self, tenant_id: str, voice: str): self._runtime_thread = threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) - def publish(self, message): - try: - self._msg_queue.put(message) - except Exception as e: - self.logger.warning(e) + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) def _runtime(self): - future_queue = queue.Queue() + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() while True: try: @@ -110,7 +116,7 @@ def _runtime(self): break future_queue.put(None) - def check_and_get_audio(self) -> AudioTrunk | None: + def check_and_get_audio(self): try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ce0e95962772ad..7c8f9f3813b378 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -197,11 +197,11 @@ def _to_stream_response( stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -222,7 +222,7 @@ def _wrapper_process_stream_response( for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -511,7 +511,7 @@ def _process_stream_response( # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 4c4d282e99b6ae..3725c6e6ddc4fd 100644 --- 
a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,6 @@ import queue import time from abc import abstractmethod -from collections.abc import Generator from enum import Enum from typing import Any @@ -11,9 +10,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueErrorEvent, QueuePingEvent, QueueStopEvent, + WorkflowQueueMessage, ) from extensions.ext_redis import redis_client @@ -37,11 +38,11 @@ def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) - q = queue.Queue() + q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q - def listen(self) -> Generator: + def listen(self): """ Listen to queue :return: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 79e5e2bcb96d60..d279002285f4c2 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -171,11 +171,11 @@ def _to_stream_response( yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -196,7 +196,7 @@ def _wrapper_process_stream_response( for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -421,7 +421,7 @@ def _process_stream_response( # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 917649f34e769c..4216cd46cfbc52 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -201,11 +201,11 @@ def _to_stream_response( stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None From 
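Stripped of the new annotations, the publisher above is an ordered producer/consumer pipeline: synthesis jobs go to a thread pool, their futures are queued in submission order, and a dedicated thread resolves them FIFO so audio chunks stay in text order. A toy sketch of that ordering trick; `synthesize` is a stand-in for the real `invoke_tts` call:

```python
# Toy sketch of the future-queue pattern behind AppGeneratorTTSPublisher;
# synthesize() is a stand-in for the real model_instance.invoke_tts call.
import concurrent.futures
import queue
import threading


def synthesize(text: str) -> bytes:
    return text.encode("utf-8")  # pretend this is synthesized audio


future_queue: queue.Queue[concurrent.futures.Future[bytes] | None] = queue.Queue()
audio_queue: queue.Queue[bytes] = queue.Queue()


def drain_futures() -> None:
    # Resolving futures in FIFO order keeps the audio ordered even though
    # the pool may finish individual jobs out of order.
    while True:
        future = future_queue.get()
        if future is None:
            break
        audio_queue.put(future.result())


drainer = threading.Thread(target=drain_futures)
drainer.start()

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
    for sentence in ["first.", "second.", "third."]:
        future_queue.put(pool.submit(synthesize, sentence))

future_queue.put(None)  # sentinel: no more work
drainer.join()

while not audio_queue.empty():
    print(audio_queue.get())
```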
5db8addcc61221e600ab9869e135c577e51f43ce Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:41:34 +0800 Subject: [PATCH 57/91] fix(core/errors): change base class of custom exceptions to ValueError (#11955) Signed-off-by: -LAN- --- api/core/errors/error.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 3b186476ebe977..ad921bc2556ffe 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,7 +1,7 @@ from typing import Optional -class LLMError(Exception): +class LLMError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None @@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError): description = "Bad Request" -class ProviderTokenNotInitError(Exception): +class ProviderTokenNotInitError(ValueError): """ Custom exception raised when the provider token is not initialized. """ @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): self.description = args[0] if args else self.description -class QuotaExceededError(Exception): +class QuotaExceededError(ValueError): """ Custom exception raised when the quota for a provider has been exceeded. """ @@ -35,7 +35,7 @@ class QuotaExceededError(Exception): description = "Quota Exceeded" -class AppInvokeQuotaExceededError(Exception): +class AppInvokeQuotaExceededError(ValueError): """ Custom exception raised when the quota for an app has been exceeded. """ @@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception): description = "App Invoke Quota Exceeded" -class ModelCurrentlyNotSupportError(Exception): +class ModelCurrentlyNotSupportError(ValueError): """ Custom exception raised when the model is not supported """ @@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception): description = "Model Currently Not Support" -class InvokeRateLimitError(Exception): +class InvokeRateLimitError(ValueError): """Raised when the Invoke returns rate limit error.""" description = "Rate Limit Error" From 2c4df108e56a54ff0d8b44e4cabf50273cfc2fae Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:41:53 +0800 Subject: [PATCH 58/91] fix: raise http request node error on httpx.request error (#11954) Signed-off-by: -LAN- --- api/core/ops/ops_trace_manager.py | 8 +++++++- api/core/workflow/nodes/http_request/executor.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index b7799ce1fbdd5e..a04fc6ee789fd0 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -355,7 +355,13 @@ def preprocess(self): def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): + def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): + if not workflow_run: + raise ValueError("Workflow run not found") + + db.session.merge(workflow_run) + db.session.refresh(workflow_run) + workflow_id = workflow_run.workflow_id tenant_id = workflow_run.tenant_id workflow_run_id = workflow_run.id diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 92f190091badeb..b96402a76af0bc 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -249,6 +249,8 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: # request_args = {k: v for k, v in request_args.items() if v is not None} try:
response = getattr(ssrf_proxy, self.method)(**request_args) + except httpx.RequestError as e: + raise HttpRequestNodeError(str(e)) except ssrf_proxy.MaxRetriesExceededError as e: raise HttpRequestNodeError(str(e)) return response From 21a31d7f8bc4eb60080ff3cd7228a19fcabb14fa Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:42:30 +0800 Subject: [PATCH 59/91] fix(base_node): change BaseNodeError exception type from Exception to ValueError (#11952) Signed-off-by: -LAN- --- api/core/workflow/nodes/base/exc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py index ec134e031cf9d3..aeecf406403e6d 100644 --- a/api/core/workflow/nodes/base/exc.py +++ b/api/core/workflow/nodes/base/exc.py @@ -1,4 +1,4 @@ -class BaseNodeError(Exception): +class BaseNodeError(ValueError): """Base class for node errors.""" pass From c6a72def8850e5bc5983f5950f98a485592f2db1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:42:51 +0800 Subject: [PATCH 60/91] fix(ops_trace_manager): handle None workflow_run in workflow_trace method and raise ValueError (#11953) Signed-off-by: -LAN- From 10caab1729cbbdcb786c99454ba9b20413af364f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:43:31 +0800 Subject: [PATCH 61/91] fix: change CredentialsValidateFailedError to inherit from ValueError (#11950) Signed-off-by: -LAN- --- api/core/model_runtime/errors/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 7fcd2133f9f8d1..16bebcc67db062 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -1,4 +1,4 @@ -class CredentialsValidateFailedError(Exception): +class CredentialsValidateFailedError(ValueError): """ Credentials validate failed error """ From 03ddee3663eedd9bfa41c998b5035c6f57775171 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:43:40 +0800 Subject: [PATCH 62/91] fix(variable_assigner): change VariableOperatorNodeError to inherit from ValueError (#11951) Signed-off-by: -LAN- --- api/core/workflow/nodes/variable_assigner/common/exc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py index a1178fb0203593..f8dbedc2901c9f 100644 --- a/api/core/workflow/nodes/variable_assigner/common/exc.py +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -1,4 +1,4 @@ -class VariableOperatorNodeError(Exception): +class VariableOperatorNodeError(ValueError): """Base error type, don't use directly.""" pass From 6b49889041f5d25cd2c5bc2e0f7915abb7027027 Mon Sep 17 00:00:00 2001 From: liuhaoran <75237518+liuhaoran1212@users.noreply.github.com> Date: Sun, 22 Dec 2024 10:45:55 +0800 Subject: [PATCH 63/91] fix: messagefeedbackapi support content (#11716) Signed-off-by: weiyang <24080293@smb956101.com> Co-authored-by: weiyang <24080293@smb956101.com> --- api/controllers/service_api/app/message.py | 3 ++- api/services/message_service.py | 9 ++++++++- web/app/components/develop/template/template.en.mdx | 8 ++++++-- web/app/components/develop/template/template.ja.mdx | 8 ++++++-- web/app/components/develop/template/template.zh.mdx | 8 ++++++-- .../develop/template/template_advanced_chat.en.mdx | 8 ++++++-- .../develop/template/template_advanced_chat.ja.mdx | 8 ++++++-- 
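Patches 57 and 59 through 62 are one coordinated change: domain exceptions that used to subclass `Exception` now subclass `ValueError`, so a caller's broad `except ValueError` handler catches them without importing each concrete type. A compact sketch of the effect; the class names here are invented for illustration, not the real hierarchy:

```python
# Compact sketch of why these patches rebase errors onto ValueError;
# NodeError and CredentialsError are invented names for illustration.
class NodeError(ValueError):
    """Base class for node errors."""


class CredentialsError(NodeError):
    """Raised when provider credentials fail validation."""


def run_node() -> None:
    raise CredentialsError("api key is invalid")


try:
    run_node()
except ValueError as e:
    # A single generic handler now covers the whole domain hierarchy.
    print(f"handled: {e}")
```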
.../develop/template/template_advanced_chat.zh.mdx | 8 ++++++-- web/app/components/develop/template/template_chat.en.mdx | 8 ++++++-- web/app/components/develop/template/template_chat.ja.mdx | 8 ++++++-- web/app/components/develop/template/template_chat.zh.mdx | 8 ++++++-- 11 files changed, 64 insertions(+), 20 deletions(-) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ada40ec9cb26bd..599401bc6f1821 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -104,10 +104,11 @@ def post(self, app_model: App, end_user: EndUser, message_id): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/services/message_service.py b/api/services/message_service.py index f432a77c80e511..be2922f4c58e76 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -151,7 +151,12 @@ def pagination_by_last_id( @classmethod def create_feedback( - cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + cls, + app_model: App, + message_id: str, + user: Optional[Union[Account, EndUser]], + rating: Optional[str], + content: Optional[str], ) -> MessageFeedback: if not user: raise ValueError("user cannot be None") @@ -164,6 +169,7 @@ def create_feedback( db.session.delete(feedback) elif rating and feedback: feedback.rating = rating + feedback.content = content elif not rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -172,6 +178,7 @@ def create_feedback( conversation_id=message.conversation_id, message_id=message.id, rating=rating, + content=content, from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx index f469076bf3c80d..877955039c82ca 100755 --- a/web/app/components/develop/template/template.en.mdx +++ b/web/app/components/develop/template/template.en.mdx @@ -346,6 +346,9 @@ The text generation application offers non-session support and is ideal for tran User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. 
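With `content` accepted by the request parser and persisted by `create_feedback`, a feedback call can now carry free-form text alongside the rating, as the updated cURL samples in these docs show. The same request sketched in Python; the base URL, API key, and message ID are placeholders, not values from the patch:

```python
# Sketch of the extended feedback call; the base URL, API key, and
# message ID below are placeholders, not values from the patch.
import httpx

BASE_URL = "https://api.example.com/v1"
API_KEY = "app-xxxxxxxx"
MESSAGE_ID = "your-message-id"

response = httpx.post(
    f"{BASE_URL}/messages/{MESSAGE_ID}/feedbacks",
    headers={
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    },
    json={
        "rating": "like",                           # "like", "dislike", or null
        "user": "abc-123",                          # end-user identifier
        "content": "message feedback information",  # new free-form field
    },
)
response.raise_for_status()
print(response.json())
```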
+ ### Response @@ -353,7 +356,7 @@ The text generation application offers non-session support and is ideal for tran - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -361,7 +364,8 @@ The text generation application offers non-session support and is ideal for tran --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template.ja.mdx b/web/app/components/develop/template/template.ja.mdx index bd92bd7f36b103..c3b376f2e743c2 100755 --- a/web/app/components/develop/template/template.ja.mdx +++ b/web/app/components/develop/template/template.ja.mdx @@ -345,6 +345,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 開発者のルールで定義されたユーザー識別子。アプリケーション内で一意である必要があります。 + + メッセージのフィードバックです。 + ### レスポンス @@ -352,7 +355,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -360,7 +363,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index 7b1bec35469c36..be7470480f36df 100755 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -320,6 +320,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -327,7 +330,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -335,7 +338,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 5f00977e75ac85..5106b6f36a8757 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -444,6 +444,9 @@ Chat applications support session persistence, allowing previous chat history to User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. 
+ ### Response @@ -451,7 +454,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -459,7 +462,8 @@ Chat applications support session persistence, allowing previous chat history to --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.ja.mdx b/web/app/components/develop/template/template_advanced_chat.ja.mdx index 7c933598f9a440..cf65a29b446de1 100644 --- a/web/app/components/develop/template/template_advanced_chat.ja.mdx +++ b/web/app/components/develop/template/template_advanced_chat.ja.mdx @@ -444,6 +444,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + メッセージのフィードバックです。 + ### 応答 @@ -451,7 +454,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -459,7 +462,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index fec0636d4092aa..662309525b9264 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -450,6 +450,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -457,7 +460,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -465,7 +468,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 1eb289b3c1d5c6..d38e80407add80 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -408,6 +408,9 @@ Chat applications support session persistence, allowing previous chat history to User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. 
+ ### Response @@ -415,7 +418,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ Chat applications support session persistence, allowing previous chat history to --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.ja.mdx b/web/app/components/develop/template/template_chat.ja.mdx index fb686e0cffdaca..96db9912d5da71 100644 --- a/web/app/components/develop/template/template_chat.ja.mdx +++ b/web/app/components/develop/template/template_chat.ja.mdx @@ -408,6 +408,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。 + + メッセージのフィードバックです。 + ### 応答 @@ -415,7 +418,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index af96cab5ffcb4e..3d6e3630be5129 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -423,6 +423,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -430,7 +433,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -438,7 +441,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` From 750662eb08670c7289b943fa515334e5a32883ad Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 14:55:18 +0800 Subject: [PATCH 64/91] fix(workflow): update updated_at default to use UTC timezone (#11960) Signed-off-by: -LAN- --- api/models/workflow.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/api/models/workflow.py b/api/models/workflow.py index 0e5cf38d25e864..7896339f37d00d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from datetime import datetime +from datetime import UTC, datetime from enum import Enum, StrEnum from typing import Any, Optional, Union @@ -106,7 +106,10 @@ class Workflow(db.Model): created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), server_onupdate=func.current_timestamp() + db.DateTime, + nullable=False, + default=datetime.now(UTC).replace(tzinfo=None), 
+ server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" From 26c10b9931bc2ea3ded8030c1de97760f178c522 Mon Sep 17 00:00:00 2001 From: yihong Date: Sun, 22 Dec 2024 14:58:33 +0800 Subject: [PATCH 65/91] fix: 'dict_keys' object is not subscriptable error (#11957) Signed-off-by: yihong0618 --- api/core/tools/provider/app_tool_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 09f328cd1fe65f..582ad636b1953a 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -62,7 +62,7 @@ def get_tools(self, user_id: str) -> list[Tool]: user_input_form_list = app_model_config.user_input_form_list for input_form in user_input_form_list: # get type - form_type = input_form.keys()[0] + form_type = list(input_form.keys())[0] default = input_form[form_type]["default"] required = input_form[form_type]["required"] label = input_form[form_type]["label"] From d9875fe232029bfd3f34c9be7bade1a7313ff9e1 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 23 Dec 2024 09:20:30 +0800 Subject: [PATCH 66/91] fix(commands): validate name encoding for non-Latin characters (#11965) --- api/commands.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/commands.py b/api/commands.py index 09548ac9f338cd..bf013cc77e0627 100644 --- a/api/commands.py +++ b/api/commands.py @@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str if language not in languages: language = "en-US" - name = name.strip() + # Validates name encoding for non-Latin characters. 
+ name = name.strip().encode("utf-8").decode("utf-8") if name else None # generate random password new_password = secrets.token_urlsafe(16) From 39df994ff9c48723d84fc78065ac3996c93c6084 Mon Sep 17 00:00:00 2001 From: yihong Date: Mon, 23 Dec 2024 09:20:47 +0800 Subject: [PATCH 67/91] fix: create_feedback args are wrong (#11962) Signed-off-by: yihong0618 --- api/controllers/console/explore/message.py | 2 +- api/controllers/web/message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3d221ff30a6599..4e11d8005f138b 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -70,7 +70,7 @@ def post(self, installed_app, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 98891f5d00d7e0..febaab5328e8b3 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ def post(self, app_model, end_user, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") From 74b1b601253d328f3369d1df2d95531e6c251ab9 Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:17:49 +0800 Subject: [PATCH 68/91] Feat: account page dark mode (#11977) --- web/app/account/account-page/index.tsx | 36 +++++++++---------- web/app/account/avatar.tsx | 12 +++---- web/app/account/layout.tsx | 2 +- web/app/components/base/modal/index.tsx | 12 +++---- web/app/components/base/select/index.tsx | 17 ++++----- .../header/account-dropdown/index.tsx | 35 +++++++++--------- .../header/account-setting/collapse/index.tsx | 10 +++--- .../data-source-page/index.tsx | 1 - .../data-source-page/panel/config-item.tsx | 14 ++++---- .../data-source-page/panel/index.tsx | 24 ++++++------- .../header/account-setting/index.tsx | 14 ++++---- .../account-setting/language-page/index.tsx | 2 +- .../account-setting/members-page/index.tsx | 22 ++++++------ web/app/components/header/header-wrapper.tsx | 2 +- web/app/components/header/index.tsx | 2 +- web/app/components/header/indicator/index.tsx | 20 +++++------ web/themes/manual-dark.css | 31 +++++++++++++--- web/themes/manual-light.css | 31 +++++++++++++--- web/themes/tailwind-theme-var-define.ts | 1 + 19 files changed, 165 insertions(+), 123 deletions(-) diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index 71540ce3b1265a..c7af05793f296b 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -18,10 +18,10 @@ import { IS_CE_EDITION } from '@/config' import Input from '@/app/components/base/input' const titleClassName = ` - text-sm font-medium text-gray-900 + system-sm-semibold text-text-secondary ` const descriptionClassName = ` - mt-1 text-xs font-normal text-gray-500 + mt-1 body-xs-regular text-text-tertiary ` const validPassword = 
/^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
@@ -122,7 +122,7 @@ export default function AccountPage() {
-
{item.name}
+
{item.name}
) }
@@ -130,7 +130,7 @@ export default function AccountPage() {
return (
<>
-

{t('common.account.myAccount')}

+

{t('common.account.myAccount')}

@@ -142,10 +142,10 @@ export default function AccountPage() {
{t('common.account.name')}
-
+
{userProfile.name}
-
+
{t('common.operation.edit')}
@@ -153,7 +153,7 @@ export default function AccountPage() {
{t('common.account.email')}
-
+
{userProfile.email}
@@ -162,14 +162,14 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('common.account.password')}
-
{t('common.account.passwordTip')}
+
{t('common.account.password')}
+
{t('common.account.passwordTip')}
) } -
+
{t('common.account.langGeniusAccount')}
{t('common.account.langGeniusAccountTip')}
@@ -181,7 +181,7 @@ export default function AccountPage() { wrapperClassName='mt-2' /> )} - {!IS_CE_EDITION && } + {!IS_CE_EDITION && }
{ editNameModalVisible && ( @@ -190,7 +190,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className={s.modal} > -
{t('common.account.editName')}
+
{t('common.account.editName')}
{t('common.account.name')}
-
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
{userProfile.is_password_set && ( <>
{t('common.account.currentPassword')}
@@ -242,7 +242,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')}
@@ -261,7 +261,7 @@ export default function AccountPage() {
-
{t('common.account.confirmPassword')}
+
{t('common.account.confirmPassword')}
-
+
{t('common.account.deleteTip')}
{t('common.account.deleteConfirmTip')}
-
{`${t('common.account.delete')}: ${userProfile.email}`}
+
{`${t('common.account.delete')}: ${userProfile.email}`}
} confirmText={t('common.operation.ok') as string}
diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx
index 544e43ab27f99f..8fdecc07bf867b 100644
--- a/web/app/account/avatar.tsx
+++ b/web/app/account/avatar.tsx
@@ -40,9 +40,9 @@ export default function AppSelector() {
className={`
inline-flex items-center
rounded-[20px] p-1x text-sm
- text-gray-700 hover:bg-gray-200
+ text-text-primary
mobile:px-1
- ${open && 'bg-gray-200'}
+ ${open && 'bg-components-panel-bg-blur'}
`}
>
@@ -60,7 +60,7 @@
@@ -78,10 +78,10 @@
handleLogout()}>
- -
{t('common.userProfile.logout')}
+ +
{t('common.userProfile.logout')}
diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx
index 5aa8b05cbfd07b..11a6abeab40782 100644
--- a/web/app/account/layout.tsx
+++ b/web/app/account/layout.tsx
@@ -21,7 +21,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
-
+
{children}
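The pattern running through this whole dark-mode patch is one substitution: hard-coded palette utilities (`text-gray-900`, `bg-gray-200`) become semantic token classes (`text-text-secondary`, `bg-components-panel-bg`) whose CSS variables are swapped per theme in `web/themes/manual-light.css` and `web/themes/manual-dark.css`. A minimal sketch of why that makes a component theme-aware; the `SectionTitle` component and both class constants are illustrative only, not part of the patch:

```tsx
import type { ReactNode } from 'react'

// Hard-coded palette: renders the same color in light and dark mode.
const legacyTitleClassName = 'text-sm font-medium text-gray-900'

// Semantic tokens: the class resolves to a CSS variable, so the rendered
// color follows whichever theme file (manual-light / manual-dark) is active.
const titleClassName = 'system-sm-semibold text-text-secondary'

// Hypothetical component using the token-based class.
const SectionTitle = ({ children }: { children: ReactNode }) => (
  <div className={titleClassName}>{children}</div>
)

export default SectionTitle
```

Because the theme switch only flips variable definitions, no component re-render logic is needed beyond toggling the theme class on the document root.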
diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx
index 5b8c4be4b824e2..3040cdb00b502a 100644
--- a/web/app/components/base/modal/index.tsx
+++ b/web/app/components/base/modal/index.tsx
@@ -1,6 +1,6 @@
import { Dialog, Transition } from '@headlessui/react'
import { Fragment } from 'react'
-import { XMarkIcon } from '@heroicons/react/24/outline'
+import { RiCloseLine } from '@remixicon/react'
import classNames from '@/utils/classnames'
// https://headlessui.com/react/dialog
@@ -39,7 +39,7 @@
leaveFrom="opacity-100"
leaveTo="opacity-0"
>
-
+
{title && {title} } - {description && + {description && {description} } {closable - &&
- + { e.stopPropagation() onClose()
diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx
index c70cf246619c9c..221d70355ffb33 100644
--- a/web/app/components/base/select/index.tsx
+++ b/web/app/components/base/select/index.tsx
@@ -2,7 +2,8 @@
import type { FC } from 'react'
import React, { Fragment, useEffect, useState } from 'react'
import { Combobox, Listbox, Transition } from '@headlessui/react'
-import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
+import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
+import { RiCheckLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import classNames from '@/utils/classnames'
import {
@@ -152,7 +153,7 @@ const Select: FC = ({
'absolute inset-y-0 right-0 flex items-center pr-4 text-gray-700',
)}
>
-
@@ -113,7 +112,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='/account' target='_self' rel='noopener noreferrer'>
{t('common.account.account')}
- + @@ -127,7 +126,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)} target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.emailSupport')}
- +
} @@ -136,7 +135,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://github.com/langgenius/dify/discussions/categories/feedbacks' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.communityFeedback')}
- +
@@ -145,7 +144,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://discord.gg/5AEfbxcd9k' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.community')}
- +
@@ -156,7 +155,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { } target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.helpCenter')}
- +
@@ -165,7 +164,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://roadmap.dify.ai' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.roadmap')}
- +
{ @@ -174,7 +173,7 @@ export default function AppSelector({ isMobile }: IAppSelector) {
setAboutVisible(true)}>
{t('common.userProfile.about')}
-
{langeniusVersionInfo.current_version}
+
{langeniusVersionInfo.current_version}
@@ -185,10 +184,10 @@ export default function AppSelector({ isMobile }: IAppSelector) {
handleLogout()}>
-
{t('common.userProfile.logout')}
- +
{t('common.userProfile.logout')}
+
diff --git a/web/app/components/header/account-setting/collapse/index.tsx b/web/app/components/header/account-setting/collapse/index.tsx index a70dca16e5c938..d0068dabed7810 100644 --- a/web/app/components/header/account-setting/collapse/index.tsx +++ b/web/app/components/header/account-setting/collapse/index.tsx @@ -25,18 +25,18 @@ const Collapse = ({ const toggle = () => setOpen(!open) return ( -
-
+
+
{title} { open - ? - : + ? + : }
{ open && ( -
+
{ items.map(item => (
onSelect && onSelect(item)}> diff --git a/web/app/components/header/account-setting/data-source-page/index.tsx b/web/app/components/header/account-setting/data-source-page/index.tsx index c3da977ca4e203..93dc2db8540cf3 100644 --- a/web/app/components/header/account-setting/data-source-page/index.tsx +++ b/web/app/components/header/account-setting/data-source-page/index.tsx @@ -12,7 +12,6 @@ export default function DataSourcePage() { return (
-
{t('common.dataSource.add')}
diff --git a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx index 2a05808e2ae20f..b7fd8193e2d86a 100644 --- a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx +++ b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx @@ -44,22 +44,22 @@ const ConfigItem: FC = ({ const onChangeAuthorizedPage = notionActions?.onChangeAuthorizedPage || function () { } return ( -
+
-
{payload.name}
+
{payload.name}
{ payload.isActive - ? + ? : } -
+
{ payload.isActive ? t(isNotion ? 'common.dataSource.notion.connected' : 'common.dataSource.website.active') : t(isNotion ? 'common.dataSource.notion.disconnected' : 'common.dataSource.website.inactive') }
-
+
{isNotion && ( = ({ { isWebsite && !readOnly && ( -
- +
+
) }
diff --git a/web/app/components/header/account-setting/data-source-page/panel/index.tsx b/web/app/components/header/account-setting/data-source-page/panel/index.tsx
index 4a810020b440ed..8d2ec0a8ca4c93 100644
--- a/web/app/components/header/account-setting/data-source-page/panel/index.tsx
+++ b/web/app/components/header/account-setting/data-source-page/panel/index.tsx
@@ -2,7 +2,7 @@
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
-import { PlusIcon } from '@heroicons/react/24/solid'
+import { RiAddLine } from '@remixicon/react'
import type { ConfigItemType } from './config-item'
import ConfigItem from './config-item'
@@ -41,12 +41,12 @@ const Panel: FC = ({
const isWebsite = type === DataSourceType.website
return (
-
+
-
+
-
{t(`common.dataSource.${type}.title`)}
+
{t(`common.dataSource.${type}.title`)}
{isWebsite && (
{t('common.dataSource.website.with')} { provider === DataSourceProvider.fireCrawl ? '🔥 Firecrawl' : 'Jina Reader'} @@ -55,7 +55,7 @@ const Panel: FC = ({
{ !isConfigured && ( -
+
{t(`common.dataSource.${type}.description`)}
) @@ -81,13 +81,13 @@ const Panel: FC = ({ <> {isSupportList &&
- - {t('common.dataSource.notion.addWorkspace')} + + {t('common.dataSource.connect')}
} ) @@ -98,8 +98,8 @@ const Panel: FC = ({ {isWebsite && !isConfigured && (
= ({ isConfigured && ( <>
-
+
{isNotion ? t('common.dataSource.notion.connectedWorkspace') : t('common.dataSource.website.configuredCrawlers')}
-
+
{ diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index d829f6b77b0cc9..4be7ec6ab7ee6f 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -152,14 +152,14 @@ export default function AccountSetting({ wrapperClassName='pt-[60px]' >
-
-
{t('common.userProfile.settings')}
+
+
{t('common.userProfile.settings')}
{ menuItems.map(menuItem => (
{!isCurrentWorkspaceDatasetOperator && ( -
{menuItem.name}
+
{menuItem.name}
)}
{ @@ -168,7 +168,7 @@ export default function AccountSetting({ key={item.key} className={` flex items-center h-[37px] mb-[2px] text-sm cursor-pointer rounded-lg - ${activeMenu === item.key ? 'font-semibold text-primary-600 bg-primary-50' : 'font-light text-gray-700'} + ${activeMenu === item.key ? 'system-sm-semibold text-components-menu-item-text-active bg-state-base-active' : 'system-sm-medium text-components-menu-item-text'} `} title={item.name} onClick={() => setActiveMenu(item.key)} @@ -185,7 +185,7 @@ export default function AccountSetting({
-
+
{activeItem?.name}
{ activeItem?.description && ( @@ -193,8 +193,8 @@ export default function AccountSetting({ ) }
-
- +
+
diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index fc8db8681359af..7d3e09fc2103ae 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -13,7 +13,7 @@ import { timezones } from '@/utils/timezone' import { languages } from '@/i18n/language' const titleClassName = ` - mb-2 text-sm font-medium text-gray-900 + mb-2 system-sm-semibold text-text-secondary ` export default function LanguagePage() { diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 03d65af7a45147..c2b722b4a7af65 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -85,32 +85,32 @@ const MembersPage = () => {
-
-
{t('common.members.name')}
-
{t('common.members.lastActive')}
-
{t('common.members.role')}
+
+
{t('common.members.name')}
+
{t('common.members.lastActive')}
+
{t('common.members.role')}
{ accounts.map(account => ( -
+
-
+
{account.name} - {account.status === 'pending' && {t('common.members.pending')}} - {userProfile.email === account.email && {t('common.members.you')}} + {account.status === 'pending' && {t('common.members.pending')}} + {userProfile.email === account.email && {t('common.members.you')}}
-
{account.email}
+
{account.email}
-
{dayjs(Number((account.last_active_at || account.created_at)) * 1000).locale(locale === 'zh-Hans' ? 'zh-cn' : 'en').fromNow()}
+
{dayjs(Number((account.last_active_at || account.created_at)) * 1000).locale(locale === 'zh-Hans' ? 'zh-cn' : 'en').fromNow()}
{ ((isCurrentWorkspaceOwner && account.role !== 'owner') || (isCurrentWorkspaceManager && !['owner', 'admin'].includes(account.role))) ? - :
{RoleMap[account.role] || RoleMap.normal}
+ :
{RoleMap[account.role] || RoleMap.normal}
}
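The `fromNow()` expression in the members-page hunk above relies on two details that are easy to miss: the API delivers `last_active_at`/`created_at` as epoch seconds, so the value must be multiplied by 1000 before it is handed to dayjs (which expects milliseconds), and the app locale id `zh-Hans` has to be mapped to dayjs's `zh-cn` locale id. A minimal standalone sketch of that logic; the `formatLastActive` helper and the sample timestamp are illustrative, not part of the patch:

```tsx
import dayjs from 'dayjs'
import relativeTime from 'dayjs/plugin/relativeTime'
import 'dayjs/locale/zh-cn'

// fromNow() lives in the relativeTime plugin; it is not in dayjs core.
dayjs.extend(relativeTime)

// Hypothetical helper mirroring the expression in the diff above.
const formatLastActive = (
  lastActiveAt: number | null,
  createdAt: number,
  locale: string,
): string =>
  dayjs(Number(lastActiveAt || createdAt) * 1000) // epoch seconds -> ms
    .locale(locale === 'zh-Hans' ? 'zh-cn' : 'en') // app locale -> dayjs locale
    .fromNow() // e.g. "3 hours ago"

console.log(formatLastActive(null, 1734912000, 'en-US'))
```

Falling back to `created_at` when `last_active_at` is unset keeps the column populated for members who have never logged in.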
diff --git a/web/app/components/header/header-wrapper.tsx b/web/app/components/header/header-wrapper.tsx
index 52728bea875e6a..dd0ec77b8218a9 100644
--- a/web/app/components/header/header-wrapper.tsx
+++ b/web/app/components/header/header-wrapper.tsx
@@ -11,7 +11,7 @@ const HeaderWrapper = ({
children,
}: HeaderWrapperProps) => {
const pathname = usePathname()
- const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools', '/account'].includes(pathname)
+ const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools'].includes(pathname)
return (
{
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [selectedSegment])
return (
-
+
{isMobile &&
Date: Mon, 23 Dec 2024 13:53:05 +0800 Subject: [PATCH 69/91] fix: fix update external dataset error in dataset list (#11989) --- api/services/dataset_service.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index a1014e8e0ad908..4e99c73ad4787a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -231,11 +231,15 @@ def update_dataset(dataset_id, data, user): DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": - dataset.retrieval_model = data.get("external_retrieval_model", None) + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", "") + permission = data.get("permission") + if permission: + dataset.permission = permission external_knowledge_id = data.get("external_knowledge_id", None) - dataset.permission = data.get("permission") db.session.add(dataset) if not external_knowledge_id: raise ValueError("External knowledge id is required.") From 4584eb305825a78cff8b6663d3fe5d840c5bd1fc Mon Sep 17 00:00:00 2001 From: yagiyuki <68677393+yagiyuki@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:53:46 +0900 Subject: [PATCH 70/91] Add custom to file types (#11966) Co-authored-by: yagiyuki --- .../components/develop/template/template_advanced_chat.en.mdx | 1 + .../components/develop/template/template_advanced_chat.ja.mdx | 1 + .../components/develop/template/template_advanced_chat.zh.mdx | 1 + web/app/components/develop/template/template_workflow.en.mdx | 1 + web/app/components/develop/template/template_workflow.ja.mdx | 1 + web/app/components/develop/template/template_workflow.zh.mdx | 1 + 6 files changed, 6 insertions(+) diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 5106b6f36a8757..c7b92acc52a3d9 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -71,6 +71,7 @@ Chat applications support session persistence, allowing previous chat history to - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (Other file types) - `transfer_method` (string) Transfer method, `remote_url` for image URL / `local_file` for file upload - `url` (string) Image URL (when the transfer method is `remote_url`) - `upload_file_id` (string) Uploaded file ID, which must be obtained by uploading through the File Upload API in advance (when the transfer method is `local_file`) diff --git a/web/app/components/develop/template/template_advanced_chat.ja.mdx b/web/app/components/develop/template/template_advanced_chat.ja.mdx index cf65a29b446de1..67e8d8f7fec888 100644 --- a/web/app/components/develop/template/template_advanced_chat.ja.mdx +++ b/web/app/components/develop/template/template_advanced_chat.ja.mdx @@ -71,6 +71,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (他のファイルタイプ) - `transfer_method` (string) 転送方法、画像URLの場合は`remote_url` / ファイルアップロードの場合は`local_file` - `url` (string) 
画像URL(転送方法が`remote_url`の場合) - `upload_file_id` (string) アップロードされたファイルID、事前にファイルアップロードAPIを通じて取得する必要があります(転送方法が`local_file`の場合) diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 662309525b9264..3bb17d850299a0 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -68,6 +68,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `image` 具体类型包含:'JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG' - `audio` 具体类型包含:'MP3', 'M4A', 'WAV', 'WEBM', 'AMR' - `video` 具体类型包含:'MP4', 'MOV', 'MPEG', 'MPGA' + - `custom` 具体类型包含:其他文件类型 - `transfer_method` (string) 传递方式: - `remote_url`: 图片地址。 - `local_file`: 上传文件。 diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index cfa5a60d47a0d2..58c533c60b1375 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -60,6 +60,7 @@ Workflow applications offers non-session support and is ideal for translation, a - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (Other file types) - `transfer_method` (string) Transfer method, `remote_url` for image URL / `local_file` for file upload - `url` (string) Image URL (when the transfer method is `remote_url`) - `upload_file_id` (string) Uploaded file ID, which must be obtained by uploading through the File Upload API in advance (when the transfer method is `local_file`) diff --git a/web/app/components/develop/template/template_workflow.ja.mdx b/web/app/components/develop/template/template_workflow.ja.mdx index b6f8fb543f7ffa..2653b4913d7d06 100644 --- a/web/app/components/develop/template/template_workflow.ja.mdx +++ b/web/app/components/develop/template/template_workflow.ja.mdx @@ -60,6 +60,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (他のファイルタイプ) - `transfer_method` (string) 転送方法、画像URLの場合は`remote_url` / ファイルアップロードの場合は`local_file` - `url` (string) 画像URL(転送方法が`remote_url`の場合) - `upload_file_id` (string) アップロードされたファイルID、事前にファイルアップロードAPIを通じて取得する必要があります(転送方法が`local_file`の場合) diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index 9cef3d18a5e945..ddffc0f02d5070 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -58,6 +58,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - `image` 具体类型包含:'JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG' - `audio` 具体类型包含:'MP3', 'M4A', 'WAV', 'WEBM', 'AMR' - `video` 具体类型包含:'MP4', 'MOV', 'MPEG', 'MPGA' + - `custom` 具体类型包含:其他文件类型 - `transfer_method` (string) 传递方式,`remote_url` 图片地址 / `local_file` 上传文件 - `url` (string) 图片地址(仅当传递方式为 `remote_url` 时) - `upload_file_id` (string) (string) 上传文件 ID(仅当传递方式为 `local_file` 时) From 4b1e13e9824113c0f4f8d48fc53a93f76449d640 Mon Sep 17 00:00:00 2001 From: JasonVV Date: Mon, 23 Dec 2024 14:30:04 +0800 Subject: [PATCH 71/91] Fix 11979 (#11984) --- api/core/workflow/nodes/llm/node.py | 30 +++++++++++++++++++---------- 1 file 
changed, 20 insertions(+), 10 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index bcba3f5a4d55f1..55fac45576c821 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -860,14 +860,16 @@ def _handle_list_messages( ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] for message in messages: - contents: list[PromptMessageContent] = [] if message.edition_type == "jinja2": result_text = _render_jinja2_message( template=message.jinja2_text or "", jinjia2_variables=jinja2_variables, variable_pool=variable_pool, ) - contents.append(TextPromptMessageContent(data=result_text)) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) else: # Get segment group from basic message if context: @@ -877,6 +879,7 @@ def _handle_list_messages( segment_group = variable_pool.convert_template(template) # Process segments for images + file_contents = [] for segment in segment_group.value: if isinstance(segment, ArrayFileSegment): for file in segment.value: @@ -884,20 +887,27 @@ def _handle_list_messages( file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) - contents.append(file_content) + file_contents.append(file_content) elif isinstance(segment, FileSegment): file = segment.value if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) - contents.append(file_content) - else: - plain_text = segment.markdown.strip() - if plain_text: - contents.append(TextPromptMessageContent(data=plain_text)) - prompt_message = _combine_message_content_with_role(contents=contents, role=message.role) - prompt_messages.append(prompt_message) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) return prompt_messages From c1aa55f3ea8ee674aa00689e97ba9ace60f5f295 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Mon, 23 Dec 2024 14:32:11 +0800 Subject: [PATCH 72/91] fix: remove the unused retry index field (#11903) Signed-off-by: -LAN- Co-authored-by: Novice Lee Co-authored-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 61 ++++---- .../apps/workflow/generate_task_pipeline.py | 62 ++++---- api/core/app/apps/workflow_app_runner.py | 92 ++++++------ api/core/app/entities/queue_entities.py | 68 ++++----- .../task_pipeline/workflow_cycle_manage.py | 5 +- api/core/helper/ssrf_proxy.py | 4 +- .../workflow/graph_engine/entities/event.py | 5 +- .../workflow/graph_engine/graph_engine.py | 7 +- api/core/workflow/nodes/event/event.py | 8 +- .../workflow/nodes/http_request/executor.py | 4 +- api/fields/workflow_run_fields.py | 11 +- ...dd_retry_index_field_to_node_execution_.py | 18 ++- ..._remove_workflow_node_executions_retry_.py | 34 +++++ api/models/model.py | 5 +- api/models/workflow.py | 1 - api/services/workflow_service.py | 141 ++++++------------ 16 files 
changed, 257 insertions(+), 269 deletions(-) create mode 100644 api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 7c8f9f3813b378..635e482ad980ed 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -180,7 +180,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -291,9 +291,27 @@ def _process_stream_response( yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not workflow_run: + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) @@ -331,63 +349,48 @@ def _process_stream_response( if response: yield response - elif isinstance( - event, - QueueNodeRetryEvent, - ): - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - if response: - yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, 
QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("workflow run not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -406,10 +409,10 @@ def _process_stream_response( self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, @@ -429,10 +432,10 @@ def _process_stream_response( self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -522,7 +525,7 @@ def _process_stream_response( yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index d279002285f4c2..c47b38f5600f4d 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -155,7 +155,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -218,7 +218,7 @@ def _wrapper_process_stream_response( break else: yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) - except Exception as e: + except Exception: logger.exception(f"Fails to get audio trunk, task_id: {task_id}") break if tts_publisher: @@ -254,9 +254,27 @@ def _process_stream_response( yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not workflow_run: + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + 
response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) @@ -289,64 +307,48 @@ def _process_stream_response( ) if node_failed_response: yield node_failed_response - elif isinstance( - event, - QueueNodeRetryEvent, - ): - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - - if response: - yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -366,10 +368,10 @@ def _process_stream_response( ) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, @@ -390,10 
+392,10 @@ def _process_stream_response( ) elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 2fbf711175aab5..885283504b4175 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -188,6 +188,41 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) elif isinstance(event, GraphRunFailedEvent): self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) + elif isinstance(event, NodeRunRetryEvent): + node_run_result = event.route_node_state.node_run_result + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + else: + inputs = {} + process_data = {} + outputs = {} + execution_metadata = {} + self._publish_event( + QueueNodeRetryEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, + inputs=inputs, + process_data=process_data, + outputs=outputs, + error=event.error, + execution_metadata=execution_metadata, + retry_index=event.retry_index, + ) + ) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -207,6 +242,17 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) ) elif isinstance(event, NodeRunSucceededEvent): + node_run_result = event.route_node_state.node_run_result + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + else: + inputs = {} + process_data = {} + outputs = {} + execution_metadata = {} self._publish_event( QueueNodeSucceededEvent( node_execution_id=event.id, @@ -218,18 +264,10 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result - else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result - else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result - else {}, - execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result - else {}, + inputs=inputs, + process_data=process_data, + outputs=outputs, + 
execution_metadata=execution_metadata, in_iteration_id=event.in_iteration_id, ) ) @@ -422,36 +460,6 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) - elif isinstance(event, NodeRunRetryEvent): - self._publish_event( - QueueNodeRetryEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - start_at=event.start_at, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result - else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result - else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result - else {}, - error=event.error, - execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result - else {}, - in_iteration_id=event.in_iteration_id, - retry_index=event.retry_index, - start_index=event.start_index, - ) - ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 49b7e80246a3f2..d73c2eb53bfcd7 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional @@ -85,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): @@ -139,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None steps: int = 0 error: Optional[str] = None @@ -304,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None @@ -314,35 +315,18 @@ class QueueNodeSucceededEvent(AppQueueEvent): iteration_duration_map: Optional[dict[str, float]] = None -class QueueNodeRetryEvent(AppQueueEvent): +class QueueNodeRetryEvent(QueueNodeStartedEvent): """QueueNodeRetryEvent entity""" event: QueueEvent = QueueEvent.RETRY - node_execution_id: str - node_id: str - node_type: NodeType - node_data: BaseNodeData - parallel_id: Optional[str] = None - """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None 
- """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None - """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None - """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None - """iteration id if node is in iteration""" - start_at: datetime - - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str retry_index: int # retry index - start_index: int # start index class QueueNodeInIterationFailedEvent(AppQueueEvent): @@ -368,10 +352,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str @@ -399,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str @@ -430,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 0e6425fa14e8cf..5061804310d236 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -445,6 +445,7 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.workflow_id = workflow_run.workflow_id workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id workflow_node_execution.node_execution_id = event.node_execution_id workflow_node_execution.node_id = event.node_id workflow_node_execution.node_type = event.node_type.value @@ -461,9 +462,11 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.execution_metadata = json.dumps( { 
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, } ) - workflow_node_execution.index = event.start_index + workflow_node_execution.index = event.node_run_index db.session.add(workflow_node_execution) db.session.commit() diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ce15a9667d8d18..424983a819ec28 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -45,6 +45,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) retries = 0 + stream = kwargs.pop("stream", False) while retries <= max_retries: try: if dify_config.SSRF_PROXY_ALL_URL: @@ -64,11 +65,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): except httpx.RequestError as e: logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + if max_retries == 0: + raise retries += 1 if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 999715316494fb..d591b68e7e72be 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent): error: str = Field(..., description="failed reason") - exceptions_count: Optional[int] = Field(description="exception count", default=0) + exceptions_count: int = Field(description="exception count", default=0) class GraphRunPartialSucceededEvent(BaseGraphEvent): @@ -97,11 +97,10 @@ class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") -class NodeRunRetryEvent(BaseNodeEvent): +class NodeRunRetryEvent(NodeRunStartedEvent): error: str = Field(..., description="error") retry_index: int = Field(..., description="which retry attempt is about to be performed") start_at: datetime = Field(..., description="retry start time") - start_index: int = Field(..., description="retry start index") ########################################### diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e292b099684b7b..d7d33c65fcdb38 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -641,7 +641,6 @@ def _run_node( run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED if node_instance.should_retry and retries < max_retries: retries += 1 - self.graph_runtime_state.node_run_steps += 1 route_node_state.node_run_result = run_result yield NodeRunRetryEvent( id=node_instance.id, @@ -649,14 +648,14 @@ def _run_node( node_type=node_instance.node_type, node_data=node_instance.node_data, route_node_state=route_node_state, - error=run_result.error, - retry_index=retries, + predecessor_node_id=node_instance.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + error=run_result.error, + retry_index=retries, start_at=retry_start_at, - start_index=self.graph_runtime_state.node_run_steps, ) time.sleep(retry_interval) continue diff --git a/api/core/workflow/nodes/event/event.py 
b/api/core/workflow/nodes/event/event.py index 0dc35e7d7740ef..137b47655102af 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -39,15 +39,9 @@ class RunRetryEvent(BaseModel): start_at: datetime = Field(..., description="Retry start time") -class SingleStepRetryEvent(BaseModel): +class SingleStepRetryEvent(NodeRunResult): """Single step retry event""" status: str = WorkflowNodeExecutionStatus.RETRY.value - inputs: dict | None = Field(..., description="input") - error: str = Field(..., description="error") - outputs: dict = Field(..., description="output") - retry_index: int = Field(..., description="Retry attempt number") - error: str = Field(..., description="error") elapsed_time: float = Field(..., description="elapsed time") - execution_metadata: dict | None = Field(..., description="execution metadata") diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index b96402a76af0bc..3b7e19331993ee 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -249,9 +249,7 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: # request_args = {k: v for k, v in request_args.items() if v is not None} try: response = getattr(ssrf_proxy, self.method)(**request_args) - except httpx.RequestError as e: - raise HttpRequestNodeError(str(e)) - except ssrf_proxy.MaxRetriesExceededError as e: + except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: raise HttpRequestNodeError(str(e)) return response diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 7c01ffc2c62e39..74fdf8bd97b23a 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -82,13 +82,15 @@ } retry_event_field = { - "error": fields.String, - "retry_index": fields.Integer, - "inputs": fields.Raw(attribute="inputs"), "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), "status": fields.String, + "inputs": fields.Raw(attribute="inputs"), + "process_data": fields.Raw(attribute="process_data"), "outputs": fields.Raw(attribute="outputs"), + "metadata": fields.Raw(attribute="metadata"), + "llm_usage": fields.Raw(attribute="llm_usage"), + "error": fields.String, + "retry_index": fields.Integer, } @@ -112,7 +114,6 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "finished_at": TimestampField, - "retry_events": fields.List(fields.Nested(retry_event_field)), } workflow_run_node_execution_list_fields = { diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py index 3254c23c96192d..814dec423c63c4 100644 --- a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py +++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py @@ -1,9 +1,7 @@ """add retry_index field to node-execution model - Revision ID: e1944c35e15e Revises: 11b07f66c737 Create Date: 2024-12-20 06:28:30.287197 - """ from alembic import op import models as models @@ -19,15 +17,21 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: - batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # We don't need these fields anymore, but this file is already merged into the main branch, + # so we need to keep this file for the sake of history, and this change will be reverted in the next migration. + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + pass # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: - batch_op.drop_column('retry_index') + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.drop_column('retry_index') + pass - # ### end Alembic commands ### + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py new file mode 100644 index 00000000000000..ea129d15f7e6e6 --- /dev/null +++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py @@ -0,0 +1,34 @@ +"""remove workflow_node_executions.retry_index if exists + +Revision ID: d7999dfa4aae +Revises: e1944c35e15e +Create Date: 2024-12-23 11:54:15.344543 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy import inspect + + +# revision identifiers, used by Alembic. +revision = 'd7999dfa4aae' +down_revision = 'e1944c35e15e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Check if column exists before attempting to remove it + conn = op.get_bind() + inspector = inspect(conn) + has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')] + + if has_column: + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_column('retry_index') + + +def downgrade(): + # No downgrade needed as we don't want to restore the column + pass diff --git a/api/models/model.py b/api/models/model.py index 04bb0a947a8ec3..1417298c79c0a2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import sqlalchemy as sa from flask import request @@ -24,6 +24,9 @@ from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from .workflow import Workflow + class DifySetup(db.Model): __tablename__ = "dify_setups" diff --git a/api/models/workflow.py b/api/models/workflow.py index 7896339f37d00d..d5be949bf44f2a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -641,7 +641,6 @@ class WorkflowNodeExecution(db.Model): created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) - retry_index = db.Column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ead552d6c2e83e..84768d5af053e4 100644 --- 
a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,7 +15,6 @@ from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent -from core.workflow.nodes.event.event import SingleStepRetryEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -221,99 +220,56 @@ def run_draft_workflow_node( # run draft workflow node start_at = time.perf_counter() - retries = 0 - max_retries = 0 - should_retry = True - retry_events = [] try: - while retries <= max_retries and should_retry: - retry_start_at = time.perf_counter() - node_instance, generator = WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) - max_retries = ( - node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0 - ) - retry_interval = node_instance.node_data.retry_config.retry_interval_seconds - node_run_result: NodeRunResult | None = None - for event in generator: - if isinstance(event, RunCompletedEvent): - node_run_result = event.run_result - - # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) - break - - if not node_run_result: - raise ValueError("Node run failed with no run result") - # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - if ( - retries == max_retries - and node_instance.node_type == NodeType.HTTP_REQUEST - and node_run_result.outputs - and not node_instance.should_continue_on_error - ): - node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED - should_retry = False - else: - if node_instance.should_retry: - node_run_result.status = WorkflowNodeExecutionStatus.RETRY - retries += 1 - node_run_result.retry_index = retries - retry_events.append( - SingleStepRetryEvent( - inputs=WorkflowEntry.handle_special_values(node_run_result.inputs) - if node_run_result.inputs - else None, - error=node_run_result.error, - outputs=WorkflowEntry.handle_special_values(node_run_result.outputs) - if node_run_result.outputs - else None, - retry_index=node_run_result.retry_index, - elapsed_time=time.perf_counter() - retry_start_at, - execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata) - if node_run_result.metadata - else None, - ) - ) - time.sleep(retry_interval) - else: - should_retry = False - if node_instance.should_continue_on_error: - node_error_args = { - "status": WorkflowNodeExecutionStatus.EXCEPTION, - "error": node_run_result.error, - "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, - } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: - node_run_result = NodeRunResult( - **node_error_args, - outputs={ - **node_instance.node_data.default_value_dict, - "error_message": node_run_result.error, - "error_type": node_run_result.error_type, - }, - ) - else: - node_run_result = NodeRunResult( - **node_error_args, - outputs={ - "error_message": node_run_result.error, - "error_type": node_run_result.error_type, - }, - ) - run_succeeded = node_run_result.status in ( - WorkflowNodeExecutionStatus.SUCCEEDED, - 
WorkflowNodeExecutionStatus.EXCEPTION, - ) - error = node_run_result.error if not run_succeeded else None + node_instance, generator = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ) + node_instance = cast(BaseNode[BaseNodeData], node_instance) + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break + + if not node_run_result: + raise ValueError("Node run failed with no run result") + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) + error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: node_instance = e.node_instance run_succeeded = False @@ -362,7 +318,6 @@ def run_draft_workflow_node( db.session.add(workflow_node_execution) db.session.commit() - workflow_node_execution.retry_events = retry_events return workflow_node_execution From 453f324f54db9c1099372806512987a97047cd07 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Mon, 23 Dec 2024 15:06:01 +0800 Subject: [PATCH 73/91] fix: retry node in iteration logs wrong (#11995) Co-authored-by: Novice Lee --- .../app/task_pipeline/workflow_cycle_manage.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 5061804310d236..72e4c796c34546 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -438,6 +438,16 @@ def _handle_workflow_node_execution_retried( elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + origin_metadata = { + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + } + merged_metadata = ( + {**jsonable_encoder(event.execution_metadata), **origin_metadata} + if event.execution_metadata is not None + else origin_metadata + ) + execution_metadata = json.dumps(merged_metadata) workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.tenant_id = workflow_run.tenant_id @@ -459,13 +469,7 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None 
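Note on the retry-log fix above: it comes down to merge order. The event's own execution_metadata is spread first and the iteration keys second, so ITERATION_ID and PARALLEL_MODE_RUN_ID always win while token counts and other node metadata survive. The dict literal removed just below did neither; it listed ITERATION_ID twice and discarded event.execution_metadata entirely, which is why retried nodes inside iterations logged wrong. A standalone sketch of the corrected precedence, with plain string keys standing in for NodeRunMetadataKey:

```python
import json


def merge_execution_metadata(event_metadata, in_iteration_id, parallel_mode_run_id) -> str:
    # Iteration-scoped keys are spread last so they override anything the
    # node itself reported under the same keys.
    origin_metadata = {
        "iteration_id": in_iteration_id,
        "parallel_mode_run_id": parallel_mode_run_id,
    }
    merged = {**event_metadata, **origin_metadata} if event_metadata is not None else origin_metadata
    return json.dumps(merged)


assert json.loads(merge_execution_metadata({"total_tokens": 42}, "iter-1", None))["total_tokens"] == 42
```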
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - } - ) + workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index db.session.add(workflow_node_execution) From 70dd69d533b58f4837eea00e87a789b8653931f9 Mon Sep 17 00:00:00 2001 From: yihong Date: Mon, 23 Dec 2024 15:09:02 +0800 Subject: [PATCH 74/91] fix: Multiple Paths Between IF/ELSE Branches Invalidate Conditions (#11544) Signed-off-by: yihong0618 --- api/core/workflow/nodes/answer/answer_stream_generate_router.py | 2 ++ api/core/workflow/nodes/end/end_stream_generate_router.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1b948bf59203b7..7d652d39f70ef4 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -147,6 +147,8 @@ def _recursive_fetch_answer_dependencies( reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") source_node_data = node_id_config_mapping[source_node_id].get("data", {}) if ( diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index ea8b6b50420c99..0db1ba9f09d36e 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -135,6 +135,8 @@ def _recursive_fetch_end_dependencies( reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in { NodeType.IF_ELSE.value, From 03548cdfbc0ed919c7db3d1626498727f022adc8 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 15:23:11 +0800 Subject: [PATCH 75/91] fix: handle broader request exceptions in OAuth process (#11997) Signed-off-by: -LAN- --- api/controllers/console/auth/oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5de8c6766d35a8..b9188aa0798ea2 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -76,7 +76,7 @@ def get(self, provider: str): try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) - except requests.exceptions.HTTPError as e: + except requests.exceptions.RequestException as e: logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") return {"error": "OAuth process failed"}, 400 From 4e3d732934294ab2528b09855239375cc118b0cc Mon Sep 17 00:00:00 2001 From: Shun Miyazawa <34241526+miya@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:27:49 +0900 Subject: [PATCH 76/91] feat: Warning on invite modal when mail setup is incomplete (#11809) --- 
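Note on the two stream-router fixes above (answer and end): both apply the same defensive rule, namely that an edge in a hand-edited draft graph can outlive its source node, so every source_node_id must be checked against node_id_config_mapping before being dereferenced. A runnable reduction, with a minimal Edge stand-in for the real entity:

```python
from dataclasses import dataclass


@dataclass
class Edge:
    source_node_id: str


def reachable_source_types(reverse_edges: list[Edge], node_id_config_mapping: dict) -> list:
    # Skip dangling edges instead of raising KeyError when the source
    # node was deleted but its edge record survived.
    types = []
    for edge in reverse_edges:
        if edge.source_node_id not in node_id_config_mapping:
            continue
        types.append(node_id_config_mapping[edge.source_node_id].get("data", {}).get("type"))
    return types


assert reachable_source_types(
    [Edge("deleted-node"), Edge("llm-1")],
    {"llm-1": {"data": {"type": "llm"}}},
) == ["llm"]
```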
.../account-setting/members-page/index.tsx | 3 ++- .../members-page/invite-modal/index.tsx | 21 ++++++++++++++++++- web/i18n/de-DE/common.ts | 1 + web/i18n/en-US/common.ts | 1 + web/i18n/es-ES/common.ts | 1 + web/i18n/fa-IR/common.ts | 1 + web/i18n/fr-FR/common.ts | 1 + web/i18n/hi-IN/common.ts | 1 + web/i18n/it-IT/common.ts | 1 + web/i18n/ja-JP/common.ts | 1 + web/i18n/ko-KR/common.ts | 1 + web/i18n/pl-PL/common.ts | 1 + web/i18n/pt-BR/common.ts | 1 + web/i18n/ro-RO/common.ts | 1 + web/i18n/ru-RU/common.ts | 1 + web/i18n/sl-SI/common.ts | 1 + web/i18n/th-TH/common.ts | 1 + web/i18n/tr-TR/common.ts | 1 + web/i18n/uk-UA/common.ts | 1 + web/i18n/vi-VN/common.ts | 1 + web/i18n/zh-Hans/common.ts | 1 + web/i18n/zh-Hant/common.ts | 1 + 22 files changed, 42 insertions(+), 2 deletions(-) diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index c2b722b4a7af65..808da454d1e1f8 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -34,7 +34,7 @@ const MembersPage = () => { } const { locale } = useContext(I18n) - const { userProfile, currentWorkspace, isCurrentWorkspaceOwner, isCurrentWorkspaceManager } = useAppContext() + const { userProfile, currentWorkspace, isCurrentWorkspaceOwner, isCurrentWorkspaceManager, systemFeatures } = useAppContext() const { data, mutate } = useSWR({ url: '/workspaces/current/members' }, fetchMembers) const [inviteModalVisible, setInviteModalVisible] = useState(false) const [invitationResults, setInvitationResults] = useState([]) @@ -122,6 +122,7 @@ const MembersPage = () => { { inviteModalVisible && ( setInviteModalVisible(false)} onSend={(invitationResults) => { setInvitedModalVisible(true) diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index 7d434953625d61..197e3ee8673827 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -4,6 +4,7 @@ import { useContext } from 'use-context-selector' import { XMarkIcon } from '@heroicons/react/24/outline' import { useTranslation } from 'react-i18next' import { ReactMultiEmail } from 'react-multi-email' +import { RiErrorWarningFill } from '@remixicon/react' import RoleSelector from './role-selector' import s from './index.module.css' import cn from '@/utils/classnames' @@ -17,11 +18,13 @@ import I18n from '@/context/i18n' import 'react-multi-email/dist/style.css' type IInviteModalProps = { + isEmailSetup: boolean onCancel: () => void onSend: (invitationResults: InvitationResult[]) => void } const InviteModal = ({ + isEmailSetup, onCancel, onSend, }: IInviteModalProps) => { @@ -59,7 +62,23 @@ const InviteModal = ({
{t('common.members.inviteTeamMember')}
-
{t('common.members.inviteTeamMemberTip')}
+
{t('common.members.inviteTeamMemberTip')}
+ {!isEmailSetup && ( +
+
+
+
+
+ +
+
+ {t('common.members.emailNotSetup')} +
+
+
+
+ )} +
{t('common.members.email')}
diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts index 1d7ca955fafbd3..f438b4f018490d 100644 --- a/web/i18n/de-DE/common.ts +++ b/web/i18n/de-DE/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Kann Apps erstellen & bearbeiten', inviteTeamMember: 'Teammitglied hinzufügen', inviteTeamMemberTip: 'Sie können direkt nach der Anmeldung auf Ihre Teamdaten zugreifen.', + emailNotSetup: 'E-Mail-Server ist nicht eingerichtet, daher können keine Einladungs-E-Mails versendet werden. Bitte informieren Sie die Benutzer über den Einladungslink, der nach der Einladung ausgestellt wird.', email: 'E-Mail', emailInvalid: 'Ungültiges E-Mail-Format', emailPlaceholder: 'Bitte E-Mails eingeben', diff --git a/web/i18n/en-US/common.ts b/web/i18n/en-US/common.ts index ce341a91485e03..ea0e4a88aaf987 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Only can manage the knowledge base', inviteTeamMember: 'Add team member', inviteTeamMemberTip: 'They can access your team data directly after signing in.', + emailNotSetup: 'Email server is not set up, so invitation emails cannot be sent. Please notify users of the invitation link that will be issued after invitation instead.', email: 'Email', emailInvalid: 'Invalid Email Format', emailPlaceholder: 'Please input emails', diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index cc9fb473299ad0..2540632758c05c 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Solo puede administrar la base de conocimiento', inviteTeamMember: 'Agregar miembro del equipo', inviteTeamMemberTip: 'Pueden acceder a tus datos del equipo directamente después de iniciar sesión.', + emailNotSetup: 'El servidor de correo no está configurado, por lo que no se pueden enviar correos de invitación. En su lugar, notifique a los usuarios el enlace de invitación que se emitirá después de la invitación.', email: 'Correo electrónico', emailInvalid: 'Formato de correo electrónico inválido', emailPlaceholder: 'Por favor ingresa correos electrónicos', diff --git a/web/i18n/fa-IR/common.ts b/web/i18n/fa-IR/common.ts index 2da6cdee8b8631..deab852ddb31ae 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'فقط می‌تواند پایگاه دانش را مدیریت کند', inviteTeamMember: 'افزودن عضو تیم', inviteTeamMemberTip: 'آنها می‌توانند پس از ورود به سیستم، مستقیماً به داده‌های تیم شما دسترسی پیدا کنند.', + emailNotSetup: 'سرور ایمیل راه‌اندازی نشده است، بنابراین ایمیل‌های دعوت نمی‌توانند ارسال شوند. لطفاً کاربران را از لینک دعوت که پس از دعوت صادر خواهد شد مطلع کنید。', email: 'ایمیل', emailInvalid: 'فرمت ایمیل نامعتبر است', emailPlaceholder: 'لطفاً ایمیل‌ها را وارد کنید', diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index 326572916c2f0b..25142c11cc9f69 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Peut construire des applications, mais ne peut pas gérer les paramètres de l\'équipe', inviteTeamMember: 'Ajouter un membre de l\'équipe', inviteTeamMemberTip: 'Ils peuvent accéder directement à vos données d\'équipe après s\'être connectés.', + emailNotSetup: 'Le serveur de messagerie n\'est pas configuré, les e-mails d\'invitation ne peuvent donc pas être envoyés. 
Veuillez informer les utilisateurs du lien d\'invitation qui sera émis après l\'invitation.', email: 'Courrier électronique', emailInvalid: 'Format de courriel invalide', emailPlaceholder: 'Veuillez entrer des emails', diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index a2c178cb185a49..aabcfc86e653e0 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -204,6 +204,7 @@ const translation = { inviteTeamMember: 'टीम सदस्य जोड़ें', inviteTeamMemberTip: 'वे साइन इन करने के बाद सीधे आपकी टीम डेटा तक पहुंच सकते हैं।', + emailNotSetup: 'ईमेल सर्वर सेट नहीं है, इसलिए आमंत्रण ईमेल नहीं भेजे जा सकते। कृपया उपयोगकर्ताओं को आमंत्रण के बाद जारी किए जाने वाले आमंत्रण लिंक के बारे में सूचित करें。', email: 'ईमेल', emailInvalid: 'अवैध ईमेल प्रारूप', emailPlaceholder: 'कृपया ईमेल दर्ज करें', diff --git a/web/i18n/it-IT/common.ts b/web/i18n/it-IT/common.ts index 35a01d71147108..4cee6dec50b058 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -208,6 +208,7 @@ const translation = { inviteTeamMember: 'Aggiungi membro del team', inviteTeamMemberTip: 'Potranno accedere ai dati del tuo team direttamente dopo aver effettuato l\'accesso.', + emailNotSetup: 'Il server email non è configurato, quindi non è possibile inviare email di invito. Si prega di notificare agli utenti il link di invito che verrà emesso dopo l\'invito.', email: 'Email', emailInvalid: 'Formato Email non valido', emailPlaceholder: 'Per favore inserisci le email', diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts index fa3fb223f4b65b..9c23cb6f161d7e 100644 --- a/web/i18n/ja-JP/common.ts +++ b/web/i18n/ja-JP/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'ナレッジベースのみを管理できる', inviteTeamMember: 'チームメンバーを招待する', inviteTeamMemberTip: '彼らはサインイン後、直接あなた様のチームデータにアクセスできます。', + emailNotSetup: 'メールサーバーがセットアップされていないので、招待メールを送信することはできません。代わりに招待後に発行される招待リンクをユーザーに通知してください。', email: 'メール', emailInvalid: '無効なメール形式', emailPlaceholder: 'メールを入力してください', diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts index ce860e000e3153..a599aa9bd1b6fc 100644 --- a/web/i18n/ko-KR/common.ts +++ b/web/i18n/ko-KR/common.ts @@ -187,6 +187,7 @@ const translation = { editorTip: '앱 빌드만 가능하고 팀 설정 관리 불가능', inviteTeamMember: '팀 멤버 초대', inviteTeamMemberTip: '로그인 후에 바로 팀 데이터에 액세스할 수 있습니다.', + emailNotSetup: '이메일 서버가 설정되지 않아 초대 이메일을 보낼 수 없습니다. 대신 초대 후 발급되는 초대 링크를 사용자에게 알려주세요.', email: '이메일', emailInvalid: '유효하지 않은 이메일 형식', emailPlaceholder: '이메일 입력', diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index baaf5292c36f39..69441dbab38180 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -198,6 +198,7 @@ const translation = { inviteTeamMember: 'Dodaj członka zespołu', inviteTeamMemberTip: 'Mogą uzyskać bezpośredni dostęp do danych Twojego zespołu po zalogowaniu.', + emailNotSetup: 'Serwer poczty nie jest skonfigurowany, więc nie można wysyłać zaproszeń e-mail. 
Proszę powiadomić użytkowników o linku do zaproszenia, który zostanie wydany po zaproszeniu.', email: 'Email', emailInvalid: 'Nieprawidłowy format e-maila', emailPlaceholder: 'Proszę podać adresy e-mail', diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index a2c74a7bee2e25..6f66e65878b2af 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Pode editar aplicativos, mas não pode gerenciar configurações da equipe', inviteTeamMember: 'Adicionar membro da equipe', inviteTeamMemberTip: 'Eles podem acessar os dados da sua equipe diretamente após fazer login.', + emailNotSetup: 'O servidor de e-mail não está configurado, então os e-mails de convite não podem ser enviados. Por favor, notifique os usuários sobre o link de convite que será emitido após o convite.', email: 'E-mail', emailInvalid: 'Formato de e-mail inválido', emailPlaceholder: 'Por favor, insira e-mails', diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index 27a0ab6bf3c9d3..0badaf5a13fbe9 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Poate construi aplicații, dar nu poate gestiona setările echipei', inviteTeamMember: 'Adaugă membru în echipă', inviteTeamMemberTip: 'Pot accesa direct datele echipei dvs. după autentificare.', + emailNotSetup: 'Serverul de e-mail nu este configurat, astfel încât e-mailurile de invitație nu pot fi trimise. Vă rugăm să notificați utilizatorii despre linkul de invitație care va fi emis după invitație.', email: 'Email', emailInvalid: 'Format de email invalid', emailPlaceholder: 'Vă rugăm să introduceți emailuri', diff --git a/web/i18n/ru-RU/common.ts b/web/i18n/ru-RU/common.ts index 6d9edf97c17cc4..64a7c9375d4c0a 100644 --- a/web/i18n/ru-RU/common.ts +++ b/web/i18n/ru-RU/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Может управлять только базой знаний', inviteTeamMember: 'Добавить участника команды', inviteTeamMemberTip: 'Они могут получить доступ к данным вашей команды сразу после входа в систему.', + emailNotSetup: 'Почтовый сервер не настроен, поэтому приглашения по электронной почте не могут быть отправлены. Пожалуйста, уведомите пользователей о ссылке для приглашения, которая будет выдана после приглашения.', email: 'Электронная почта', emailInvalid: 'Неверный формат электронной почты', emailPlaceholder: 'Пожалуйста, введите адреса электронной почты', diff --git a/web/i18n/sl-SI/common.ts b/web/i18n/sl-SI/common.ts index dc399bd3a432a0..0c5d1dfc4b63f3 100644 --- a/web/i18n/sl-SI/common.ts +++ b/web/i18n/sl-SI/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Lahko upravlja samo bazo znanja', inviteTeamMember: 'Dodaj člana ekipe', inviteTeamMemberTip: 'Do vaših podatkov bo lahko dostopal takoj po prijavi.', + emailNotSetup: 'E-poštni strežnik ni nastavljen, zato vabil po e-pošti ni mogoče poslati. 
Prosimo, obvestite uporabnike o povezavi za povabilo, ki bo izdana po povabilu.', email: 'E-pošta', emailInvalid: 'Neveljaven format e-pošte', emailPlaceholder: 'Vnesite e-poštne naslove', diff --git a/web/i18n/th-TH/common.ts b/web/i18n/th-TH/common.ts index 82eddae723b667..6aa1b30610ec78 100644 --- a/web/i18n/th-TH/common.ts +++ b/web/i18n/th-TH/common.ts @@ -194,6 +194,7 @@ const translation = { datasetOperatorTip: 'สามารถจัดการฐานความรู้ได้เท่านั้น', inviteTeamMember: 'เพิ่มสมาชิกในทีม', inviteTeamMemberTip: 'พวกเขาสามารถเข้าถึงข้อมูลทีมของคุณได้โดยตรงหลังจากลงชื่อเข้าใช้', + emailNotSetup: 'เซิร์ฟเวอร์อีเมลไม่ได้ตั้งค่าไว้ จึงไม่สามารถส่งอีเมลเชิญได้ กรุณาแจ้งผู้ใช้เกี่ยวกับลิงก์เชิญที่จะออกหลังจากการเชิญแทน', email: 'อีเมล', emailInvalid: 'รูปแบบอีเมลไม่ถูกต้อง', emailPlaceholder: 'กรุณากรอกอีเมล', diff --git a/web/i18n/tr-TR/common.ts b/web/i18n/tr-TR/common.ts index 320517925a2f14..9792f07e18aece 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Sadece bilgi tabanını yönetebilir', inviteTeamMember: 'Takım Üyesi Ekle', inviteTeamMemberTip: 'Giriş yaptıktan sonra takım verilerinize doğrudan erişebilirler.', + emailNotSetup: 'E-posta sunucusu kurulu değil, bu nedenle davet e-postaları gönderilemiyor. Lütfen kullanıcıları davetten sonra verilecek davet bağlantısı hakkında bilgilendirin.', email: 'E-posta', emailInvalid: 'Geçersiz E-posta Formatı', emailPlaceholder: 'Lütfen e-postaları girin', diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index 6bf6dafc4fe265..fbe9b677509845 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Може створювати програми, але не може керувати налаштуваннями команди', inviteTeamMember: 'Додати учасника команди', inviteTeamMemberTip: 'Вони зможуть отримати доступ до даних вашої команди безпосередньо після входу.', + emailNotSetup: 'Поштовий сервер не налаштований, тому запрошення електронною поштою не можуть бути надіслані. Будь ласка, повідомте користувачів про посилання для запрошення, яке буде видано після запрошення.', email: 'Електронна пошта', emailInvalid: 'Недійсний формат електронної пошти', emailPlaceholder: 'Будь ласка, введіть адресу електронної пошти', diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts index bf5339f40ece94..8bafd868542d3d 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Có thể xây dựng ứng dụng, không thể quản lý cài đặt nhóm', inviteTeamMember: 'Mời thành viên nhóm', inviteTeamMemberTip: 'Sau khi đăng nhập, họ có thể truy cập trực tiếp vào dữ liệu nhóm của bạn.', + emailNotSetup: 'Máy chủ email chưa được thiết lập, vì vậy không thể gửi email mời. 
Vui lòng thông báo cho người dùng về liên kết mời sẽ được phát hành sau khi mời.', email: 'Email', emailInvalid: 'Định dạng Email không hợp lệ', emailPlaceholder: 'Vui lòng nhập email', diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 70caaa976ae603..96e08a933756ba 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -197,6 +197,7 @@ const translation = { datasetOperatorTip: '只能管理知识库', inviteTeamMember: '添加团队成员', inviteTeamMemberTip: '对方在登录后可以访问你的团队数据。', + emailNotSetup: '由于邮件服务器未设置,无法发送邀请邮件。请将邀请后生成的邀请链接通知用户。', email: '邮箱', emailInvalid: '邮箱格式无效', emailPlaceholder: '输入邮箱', diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index 09c1e9d839d527..834065099321dd 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: '能夠建立並編輯應用程式,不能管理團隊設定', inviteTeamMember: '新增團隊成員', inviteTeamMemberTip: '對方在登入後可以訪問你的團隊資料。', + emailNotSetup: '由於郵件伺服器未設置,無法發送邀請郵件。請將邀請後生成的邀請連結通知用戶。', email: '郵箱', emailInvalid: '郵箱格式無效', emailPlaceholder: '輸入郵箱', From 8978a6a3ffba04db8aba8f5c8fd08c1c54f9a37c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 15:30:58 +0800 Subject: [PATCH 77/91] fix: remove unused credential validation logic in VectorizerProvider (#12000) Signed-off-by: -LAN- --- .../provider/builtin/vectorizer/vectorizer.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 211ec78f4d6a58..9d7613f8eaf170 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,32 +1,8 @@ from typing import Any -from core.file import FileTransferMethod, FileType -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - mapping = { - "transfer_method": FileTransferMethod.TOOL_FILE, - "type": FileType.IMAGE, - "id": "test_id", - "url": "https://cloud.dify.ai/logo/logo-site.png", - } - test_img = file_factory.build_from_mapping( - mapping=mapping, - tenant_id="__test_123", - ) - try: - VectorizerTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id="", - tool_parameters={"mode": "test", "image": test_img}, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + return From 9cfd1c67b6c16ec199b1be740a91a47fd57791c6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 15:52:50 +0800 Subject: [PATCH 78/91] fix: Introduce ArrayVariable and update iteration node to handle it (#12001) Signed-off-by: -LAN- --- api/core/variables/__init__.py | 2 ++ api/core/variables/variables.py | 13 +++++++++---- .../workflow/nodes/iteration/iteration_node.py | 15 +++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 2b1a58f93aa3c6..7a1cbf99407ea8 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -21,6 +21,7 @@ ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + ArrayVariable, FileVariable, FloatVariable, IntegerVariable, @@ 
-43,6 +44,7 @@ "ArraySegment", "ArrayStringSegment", "ArrayStringVariable", + "ArrayVariable", "FileSegment", "FileVariable", "FloatSegment", diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index c902303eef54d4..f9268b52e6c490 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -10,6 +10,7 @@ ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, + ArraySegment, ArrayStringSegment, FileSegment, FloatSegment, @@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable): pass -class ArrayAnyVariable(ArrayAnySegment, Variable): +class ArrayVariable(ArraySegment, Variable): pass -class ArrayStringVariable(ArrayStringSegment, Variable): +class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): pass -class ArrayNumberVariable(ArrayNumberSegment, Variable): +class ArrayStringVariable(ArrayStringSegment, ArrayVariable): pass -class ArrayObjectVariable(ArrayObjectSegment, Variable): +class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): + pass + + +class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): pass diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d935228c16abe9..6a89cbfad61684 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from configs import dify_config -from core.variables import IntegerVariable +from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.workflow.entities.node_entities import ( NodeRunMetadataKey, NodeRunResult, @@ -75,12 +75,15 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. """ - iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - if not iterator_list_segment: - raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") + if not variable: + raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - if len(iterator_list_segment.value) == 0: + if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): + raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") + + if isinstance(variable, NoneVariable) or len(variable.value) == 0: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -89,7 +92,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: ) return - iterator_list_value = iterator_list_segment.to_object() + iterator_list_value = variable.to_object() if not isinstance(iterator_list_value, list): raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") From e068bbec737ed649ba2294fb51b547bec6ea1c22 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 15:53:03 +0800 Subject: [PATCH 79/91] feat: add RequestBodyError for invalid request body handling (#11994) Signed-off-by: -LAN- --- api/core/workflow/nodes/http_request/exc.py | 4 ++++ api/core/workflow/nodes/http_request/executor.py | 9 +++++++++ api/core/workflow/nodes/http_request/node.py | 6 +++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index 
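Note on the ArrayVariable change above: its value shows up in the iteration node, which no longer needs to enumerate every concrete array variable class; one isinstance check now covers string, number, object, file and any-typed arrays alike. An illustrative reduction of the hierarchy (the real classes also mix in the matching segment types):

```python
class Segment: ...
class Variable(Segment): ...
class ArraySegment(Segment): ...


class ArrayVariable(ArraySegment, Variable): ...


class ArrayStringVariable(ArrayVariable): ...
class ArrayNumberVariable(ArrayVariable): ...


def is_iterable_variable(variable) -> bool:
    # One check accepts every concrete array variable via the shared base.
    return isinstance(variable, ArrayVariable)


assert is_iterable_variable(ArrayStringVariable())
assert is_iterable_variable(ArrayNumberVariable())
assert not is_iterable_variable(Variable())
```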
7a5ab7dbc1c1fa..a815f277becf9b 100644 --- a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError): class ResponseSizeError(HttpRequestNodeError): """Raised when the response size exceeds the allowed threshold.""" + + +class RequestBodyError(HttpRequestNodeError): + """Raised when the request body is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 3b7e19331993ee..575db15d365efb 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -23,6 +23,7 @@ FileFetchError, HttpRequestNodeError, InvalidHttpMethodError, + RequestBodyError, ResponseSizeError, ) @@ -143,13 +144,19 @@ def _init_body(self): case "none": self.content = "" case "raw-text": + if len(data) != 1: + raise RequestBodyError("raw-text body type should have exactly one item") self.content = self.variable_pool.convert_template(data[0].value).text case "json": + if len(data) != 1: + raise RequestBodyError("json body type should have exactly one item") json_string = self.variable_pool.convert_template(data[0].value).text json_object = json.loads(json_string, strict=False) self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": + if len(data) != 1: + raise RequestBodyError("binary body type should have exactly one item") file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: @@ -317,6 +324,8 @@ def to_log(self): elif self.json: body = json.dumps(self.json) elif self.node_data.body.type == "raw-text": + if len(self.node_data.body.data) != 1: + raise RequestBodyError("raw-text body type should have exactly one item") body = self.node_data.body.data[0].value if body: raw += f"Content-Length: {len(body)}\r\n" diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 171389a34c1c41..ebed690f6f3ffb 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -20,7 +20,7 @@ HttpRequestNodeTimeout, Response, ) -from .exc import HttpRequestNodeError +from .exc import HttpRequestNodeError, RequestBodyError HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -136,9 +136,13 @@ def _extract_variable_selector_to_variable_mapping( data = node_data.body.data match body_type: case "binary": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selector = data[0].file selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) case "json" | "raw-text": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selectors += variable_template_parser.extract_selectors_from_template(data[0].key) selectors += variable_template_parser.extract_selectors_from_template(data[0].value) case "x-www-form-urlencoded": From ef95b1268ebfbb9d9e77c4a62ebe97789b859020 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Mon, 23 Dec 2024 15:55:50 +0800 Subject: [PATCH 80/91] Fix/workflow retry (#11999) --- .../workflow/hooks/use-workflow-run.ts | 71 ++++++++++++++++--- .../_base/components/retry/retry-on-node.tsx | 4 +- web/app/components/workflow/run/index.tsx | 39 ++++++++-- .../workflow/run/iteration-result-panel.tsx | 40 ++++++++--- 
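Note on the executor change above: it replaces what would otherwise surface as a bare IndexError on data[0] with a typed RequestBodyError whenever a single-item body type is misconfigured. The guard is identical for the raw-text, json and binary branches; condensed into one function for illustration:

```python
class RequestBodyError(Exception):
    """Raised when the request body is invalid."""


def single_body_item(body_type: str, data: list):
    # raw-text, json and binary bodies must carry exactly one item; any
    # other count means the node configuration is malformed.
    if body_type in {"raw-text", "json", "binary"}:
        if len(data) != 1:
            raise RequestBodyError(f"{body_type} body type should have exactly one item")
        return data[0]
    return None


assert single_body_item("json", ['{"a": 1}']) == '{"a": 1}'
```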
web/app/components/workflow/run/node.tsx | 5 ++ 5 files changed, 130 insertions(+), 29 deletions(-) diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 822aa490dbe4ec..cc1b0724a9184f 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -387,6 +387,9 @@ export const useWorkflowRun = () => { if (nodeIndex !== -1) { currIteration[nodeIndex] = { ...currIteration[nodeIndex], + ...(currIteration[nodeIndex].retryDetail + ? { retryDetail: currIteration[nodeIndex].retryDetail } + : {}), ...data, } as any } @@ -626,6 +629,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -633,19 +638,65 @@ export const useWorkflowRun = () => { } = store.getState() const nodes = getNodes() - setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const tracing = draft.tracing! - const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id) + const currentNode = nodes.find(node => node.id === data.node_id)! + const nodeParent = nodes.find(node => node.id === currentNode.parentId) + if (nodeParent) { + if (!data.execution_metadata.parallel_mode_run_id) { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iteration = tracing.find(trace => trace.node_id === nodeParent.id) - if (currentRetryNodeIndex > -1) { - const currentRetryNode = tracing[currentRetryNodeIndex] - if (currentRetryNode.retryDetail) - draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing) + if (iteration && iteration.details?.length) { + const currentNodeRetry = iteration.details[nodeParent.data._iterationIndex - 1]?.find(item => item.node_id === data.node_id) - else - draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing] + if (currentNodeRetry) { + if (currentNodeRetry?.retryDetail) + currentNodeRetry?.retryDetail.push(data as NodeTracing) + else + currentNodeRetry.retryDetail = [data as NodeTracing] + } + } + })) } - })) + else { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iteration = tracing.find(trace => trace.node_id === nodeParent.id) + + if (iteration && iteration.details?.length) { + const iterRunID = data.execution_metadata?.parallel_mode_run_id + + const currIteration = iterParallelLogMap.get(iteration.node_id)?.get(iterRunID) + const currentNodeRetry = currIteration?.find(item => item.node_id === data.node_id) + + if (currentNodeRetry) { + if (currentNodeRetry?.retryDetail) + currentNodeRetry?.retryDetail.push(data as NodeTracing) + else + currentNodeRetry.retryDetail = [data as NodeTracing] + } + setIterParallelLogMap(iterParallelLogMap) + const iterLogMap = iterParallelLogMap.get(iteration.node_id) + if (iterLogMap) + iteration.details = Array.from(iterLogMap.values()) + } + })) + } + } + else { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! 
+ const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id) + + if (currentRetryNodeIndex > -1) { + const currentRetryNode = tracing[currentRetryNodeIndex] + if (currentRetryNode.retryDetail) + draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing) + else + draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing] + } + })) + } const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx index f5d2f08ac8507f..9e1d8e1da6fce1 100644 --- a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx @@ -31,7 +31,7 @@ const RetryOnNode = ({ }, [data._runningStatus, showSelectedBorder]) const showDefault = !isRunning && !isSuccessful && !isException && !isFailed - if (!retry_config) + if (!retry_config?.retry_enabled) return null return ( @@ -74,7 +74,7 @@ const RetryOnNode = ({ }
{ - !showDefault && ( + !showDefault && !!data._retryIndex && (
{data._retryIndex}/{data.retry_config?.max_retries}
diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 520c59bf4c682d..8b0319cabe2ce9 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -78,11 +78,24 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const groupMap = nodeGroupMap.get(iterationNode.node_id)! - if (!groupMap.has(runId)) + if (!groupMap.has(runId)) { groupMap.set(runId, [item]) + } + else { + if (item.status === 'retry') { + const retryNode = groupMap.get(runId)!.find(node => node.node_id === item.node_id) - else - groupMap.get(runId)!.push(item) + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + } + else { + groupMap.get(runId)!.push(item) + } + } if (item.status === 'failed') { iterationNode.status = 'failed' @@ -94,10 +107,24 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { const { details } = iterationNode if (details) { - if (!details[index]) + if (!details[index]) { details[index] = [item] - else - details[index].push(item) + } + else { + if (item.status === 'retry') { + const retryNode = details[index].find(node => node.node_id === item.node_id) + + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + } + else { + details[index].push(item) + } + } } if (item.status === 'failed') { diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index b13eadec99bcb2..b809e1e669d42d 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -11,6 +11,7 @@ import { import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' import { NodeRunningStatus } from '../types' import TracingPanel from './tracing-panel' +import RetryResultPanel from './retry-result-panel' import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import type { IterationDurationMap, NodeTracing } from '@/types/workflow' @@ -41,8 +42,8 @@ const IterationResultPanel: FC = ({ })) }, []) const countIterDuration = (iteration: NodeTracing[], iterDurationMap: IterationDurationMap): string => { - const IterRunIndex = iteration[0].execution_metadata.iteration_index as number - const iterRunId = iteration[0].execution_metadata.parallel_mode_run_id + const IterRunIndex = iteration[0]?.execution_metadata?.iteration_index as number + const iterRunId = iteration[0]?.execution_metadata?.parallel_mode_run_id const iterItem = iterDurationMap[iterRunId || IterRunIndex] const duration = iterItem return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s` @@ -74,6 +75,10 @@ const IterationResultPanel: FC = ({ ) } + const [retryRunResult, setRetryRunResult] = useState | undefined>() + const handleRetryDetail = (v: number, detail?: NodeTracing[]) => { + setRetryRunResult({ ...retryRunResult, [v]: detail }) + } const main = ( <> @@ -116,15 +121,28 @@ const IterationResultPanel: FC = ({ {expandedIterations[index] &&
} -
- -
+ { + !retryRunResult?.[index] && ( +
+ handleRetryDetail(index, v)} + /> +
+ ) + } + { + retryRunResult?.[index] && ( + handleRetryDetail(index, undefined)} + /> + ) + }
))}
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index bb07bd1e8c61b8..d2da319a026405 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -216,6 +216,11 @@ const NodePanel: FC = ({ {nodeInfo.error} )} + {nodeInfo.status === 'retry' && ( + + {nodeInfo.error} + + )}
{nodeInfo.inputs && (
From c4091c4c6627e257765e15807e94763a63218db0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 16:28:54 +0800 Subject: [PATCH 81/91] fix: improve error handling for file retrieval in AwsS3Storage (#12002) Signed-off-by: -LAN- --- api/extensions/storage/aws_s3_storage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ab2d0fba3b19f3..ce36c2e7deeeda 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -67,7 +67,9 @@ def load_stream(self, filename: str) -> Generator: yield from response["Body"].iter_chunks() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") + raise FileNotFoundError("file not found") + elif "reached max retries" in str(ex): + raise ValueError("please do not request the same file too frequently") else: raise From dfc25dbdd036746d93e9cf6cf7ecd09317644abc Mon Sep 17 00:00:00 2001 From: yihong Date: Mon, 23 Dec 2024 16:30:04 +0800 Subject: [PATCH 82/91] fix: drop useless and wrong code in Account (#11961) Signed-off-by: yihong0618 --- api/models/account.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/api/models/account.py b/api/models/account.py index 932ba1da578b8a..a8602d10a97308 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -99,11 +99,6 @@ def get_by_openid(cls, provider: str, open_id: str) -> db.Model: return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None - def get_integrates(self) -> list[db.Model]: - ai = db.Model - return db.session.query(ai).filter(ai.account_id == self.id).all() - - # check current_user.current_tenant.current_role in ['admin', 'owner'] @property def is_admin_or_owner(self): return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) From dc19cd5d9d1bfb354a522c86839e38501d973709 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Mon, 23 Dec 2024 16:42:28 +0800 Subject: [PATCH 83/91] fix: add retry feature to code node (#12005) Co-authored-by: Novice Lee --- api/core/workflow/nodes/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 32fdc048d133c0..7970a49aa42df4 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -35,4 +35,4 @@ class FailBranchSourceHandle(StrEnum): CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] -RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST] +RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE From 2bf33c4dd2862818e0a0d1e90ca8f85785a9cf28 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Mon, 23 Dec 2024 17:18:09 +0800 Subject: [PATCH 84/91] Fix/workflow retry log (#12013) --- .../workflow/nodes/_base/components/retry/retry-on-node.tsx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx index 9e1d8e1da6fce1..34c3e28d2c94e1 100644 --- a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx @@ -34,6 +34,9 @@ const RetryOnNode = ({ if (!retry_config?.retry_enabled) return null + if (!showDefault && !data._retryIndex) + return null 
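Note on the AwsS3Storage fix a few patches above: callers can now tell a missing key apart from client-side retry exhaustion instead of unpacking one generic ClientError. A hedged usage sketch, where storage stands in for the real AwsS3Storage instance:

```python
def read_object(storage, filename: str) -> bytes:
    # FileNotFoundError: the key does not exist (NoSuchKey).
    # ValueError: boto3 gave up after repeated requests for the same key.
    try:
        return b"".join(storage.load_stream(filename))
    except FileNotFoundError:
        return b""
    except ValueError as e:
        raise RuntimeError(f"back off before re-reading {filename}: {e}")
```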
+ return (
Date: Mon, 23 Dec 2024 17:53:25 +0800 Subject: [PATCH 85/91] Fix/add retry mechanism to billing service request handling (#12006) Signed-off-by: -LAN- --- api/services/billing_service.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 911d2346415ce5..edc51682179cc5 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,6 +1,7 @@ import os -import requests +import httpx +from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -39,11 +40,17 @@ def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): return cls._send_request("GET", "/invoices", params=params) @classmethod + @retry( + wait=wait_fixed(2), + stop=stop_before_delay(10), + retry=retry_if_not_exception_type(httpx.RequestError), + reraise=True, + ) def _send_request(cls, method, endpoint, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers) return response.json() From 75bce2822ed082ba6c3a1e2c78cd6e4b94b2001a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 17:53:36 +0800 Subject: [PATCH 86/91] fix: add logging for missing edge mapping in StreamProcessor (#12008) Signed-off-by: -LAN- --- api/core/workflow/nodes/answer/base_stream_processor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 36c3fe180a9cb2..031347662e441c 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from collections.abc import Generator @@ -5,6 +6,8 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph +logger = logging.getLogger(__name__) + class StreamProcessor(ABC): def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: @@ -31,6 +34,9 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: if run_result.edge_source_handle: reachable_node_ids = [] unreachable_first_node_ids = [] + if finished_node_id not in self.graph.edge_mapping: + logger.warning(f"node {finished_node_id} has no edge mapping") + return for edge in self.graph.edge_mapping[finished_node_id]: if ( edge.run_condition From d0dd8b79554144208d403578e297eb9f49bc139a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 17:53:42 +0800 Subject: [PATCH 87/91] fix: add UUID validation for tool file ID extraction (#12011) Signed-off-by: -LAN- --- api/core/workflow/nodes/tool/tool_node.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3b56f9487631c5..983fa7e623177a 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,6 @@ from collections.abc import Mapping, Sequence from typing import Any +from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session @@ 
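Note on the billing retry policy above, since the predicate direction is easy to misread: retry_if_not_exception_type(httpx.RequestError) retries on everything except transport errors, so connect and read failures propagate to the caller immediately, while other failures (for example a JSON decode error on a flaky response body) are retried every 2 seconds until roughly 10 seconds have elapsed and then re-raised because of reraise=True. A standalone sketch of the same policy:

```python
import httpx
from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed


@retry(
    wait=wait_fixed(2),
    stop=stop_before_delay(10),
    retry=retry_if_not_exception_type(httpx.RequestError),
    reraise=True,
)
def fetch_json(url: str) -> dict:
    # Transport errors are not retried; response-handling errors are.
    response = httpx.request("GET", url)
    return response.json()
```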
-231,6 +232,10 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE tool_file_id = url.split("/")[-1].split(".")[0] + try: + UUID(tool_file_id) + except ValueError: + raise ToolFileError(f"cannot extract tool file id from url {url}") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) From af2888d3947654668ff01f9c6606fcfa0d0f6381 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 17:53:57 +0800 Subject: [PATCH 88/91] fix: remove json_schema if response format is disabled. (#12014) Signed-off-by: -LAN- --- api/core/model_runtime/model_providers/openai/llm/llm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index b73ce8752f13f3..73cd7e3c341881 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -421,7 +421,11 @@ def _generate( # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs + prompt=prompt_messages[0].content, + model=model, + stream=stream, + **model_parameters, + **extra_model_kwargs, ) if stream: @@ -593,6 +597,8 @@ def _chat_generate( model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: model_parameters["response_format"] = {"type": response_format} + elif "json_schema" in model_parameters: + del model_parameters["json_schema"] extra_model_kwargs = {} From c3c85276d194d4de738214fb84d39bfa2c15bd3d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 17:54:08 +0800 Subject: [PATCH 89/91] Fix/refactor invoke result handling in question classifier node (#12015) Signed-off-by: -LAN- --- .../question_classifier_node.py | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 5043e25e2b8820..31f8368d590ea9 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,8 @@ import json -import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole @@ -96,27 +94,28 @@ def _run(self): jinja2_variables=[], ) - # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - ) - result_text = "" usage = LLMUsage.empty_usage() finish_reason = None - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - category_name = node_data.classes[0].name - category_id = node_data.classes[0].id try: + # handle invoke result + generator = 
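Note on the tool-file change above: it stops trusting the tail of the URL, requiring the extracted id to parse as a UUID before it reaches a database query, so crafted or truncated URLs fail fast with ToolFileError. The extraction and validation step in isolation:

```python
from uuid import UUID


class ToolFileError(Exception):
    pass


def extract_tool_file_id(url: str) -> str:
    # Last path segment, extension stripped; must be a well-formed UUID.
    tool_file_id = url.split("/")[-1].split(".")[0]
    try:
        UUID(tool_file_id)
    except ValueError:
        raise ToolFileError(f"cannot extract tool file id from url {url}")
    return tool_file_id


assert (
    extract_tool_file_id("https://example.com/files/0b3a5c78-3a9a-4f2b-9e4d-0a1b2c3d4e5f.png")
    == "0b3a5c78-3a9a-4f2b-9e4d-0a1b2c3d4e5f"
)
```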
self._invoke_llm( + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + ) + + for event in generator: + if isinstance(event, ModelInvokeCompletedEvent): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + + category_name = node_data.classes[0].name + category_id = node_data.classes[0].id result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) if "category_name" in result_text_json and "category_id" in result_text_json: @@ -127,10 +126,6 @@ def _run(self): if category_id_result in category_ids: category_name = classes_map[category_id_result] category_id = category_id_result - - except OutputParserError: - logging.exception(f"Failed to parse result text: {result_text}") - try: process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -154,7 +149,7 @@ def _run(self): }, llm_usage=usage, ) - except Exception as e: + except ValueError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, From e0f1410b48a7330b6a3b85f2103e9e889e9ccde1 Mon Sep 17 00:00:00 2001 From: yihong Date: Mon, 23 Dec 2024 18:56:59 +0800 Subject: [PATCH 90/91] fix: issue Multiple Paths Between IF/ELSE Branches (#11646) Signed-off-by: yihong0618 --- api/core/workflow/nodes/answer/answer_stream_processor.py | 1 - api/core/workflow/nodes/answer/base_stream_processor.py | 8 +++++++- .../core/workflow/graph_engine/test_graph_engine.py | 2 -- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d94f0590584842..ed033e7f283961 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -60,7 +60,6 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] - # remove unreachable nodes self._remove_unreachable_nodes(event) # generate stream outputs diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 031347662e441c..d785397e130565 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -43,7 +43,13 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: and edge.run_condition.branch_identify and run_result.edge_source_handle == edge.run_condition.branch_identify ): - reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + # remove unreachable nodes + # FIXME: because of the code branch can combine directly, so for answer node + # we remove the node maybe shortcut the answer node, so comment this code for now + # there is not effect on the answer node and the workflow, when we have a better solution + # we can open this code. 
Issues: #11542 #9560 #10638 #10564 + + # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) continue else: unreachable_first_node_ids.append(edge.target_node_id) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 9f1ba7b6af9c80..b7d8f69e8c52ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -488,14 +488,12 @@ def test_run_branch(mock_close, mock_remove): items = [] generator = graph_engine.run() for item in generator: - # print(type(item), item) items.append(item) assert len(items) == 10 assert items[3].route_node_state.node_id == "if-else-1" assert items[4].route_node_state.node_id == "if-else-1" assert isinstance(items[5], NodeRunStreamChunkEvent) - assert items[5].chunk_content == "1 " assert isinstance(items[6], NodeRunStreamChunkEvent) assert items[6].chunk_content == "takato" assert items[7].route_node_state.node_id == "answer-1" From e88ea71aefd2adca64683ce0cca836df015d0f4f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 23 Dec 2024 19:15:48 +0800 Subject: [PATCH 91/91] chore/bump version to 0.14.2 (#12017) Signed-off-by: -LAN- --- api/configs/packaging/__init__.py | 2 +- api/services/app_dsl_service.py | 2 +- docker-legacy/docker-compose.yaml | 6 +++--- docker/docker-compose-template.yaml | 6 +++--- docker/docker-compose.yaml | 6 +++--- web/package.json | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 57cd74af1f0763..4a168a3fb13947 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.14.1", + default="0.14.2", ) COMMIT_SHA: str = Field( diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 0478903fa4d2d7..7c1a175988071e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -22,7 +22,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes -CURRENT_DSL_VERSION = "0.1.4" +CURRENT_DSL_VERSION = "0.1.5" class ImportMode(StrEnum): diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 3bf4333ad17edb..1cff58be7f661d 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:0.14.1 + image: langgenius/dify-web:0.14.2 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 8370d82daa6cdf..d4e0ba49d0ba9a 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. @@ -25,7 +25,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. @@ -47,7 +47,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.14.1 + image: langgenius/dify-web:0.14.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 99bc14c7171e4b..7122f4a6d0f768 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -390,7 +390,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. @@ -413,7 +413,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.1 + image: langgenius/dify-api:0.14.2 restart: always environment: # Use the shared environment variables. @@ -435,7 +435,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.14.1 + image: langgenius/dify-web:0.14.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index 25cf231a95ed63..d9515645c8b251 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.14.1", + "version": "0.14.2", "private": true, "engines": { "node": ">=18.17.0"