From c26ed3ec63c4a10c035032278879a1e7ae7ef1b6 Mon Sep 17 00:00:00 2001 From: ProseGuys <798770222@qq.com> Date: Tue, 26 Nov 2024 17:26:25 +0800 Subject: [PATCH 1/3] feat: add retireval_top_n to config in env --- api/.env.example | 4 +++- api/configs/feature/__init__.py | 5 +++++ api/core/rag/datasource/retrieval_service.py | 18 +++++++++++++++--- docker/docker-compose.yaml | 1 + 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/api/.env.example b/api/.env.example index f8a281256380c2..e12c76a229f16a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -410,4 +410,6 @@ POSITION_PROVIDER_EXCLUDES= # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -CREATE_TIDB_SERVICE_JOB_ENABLED=false \ No newline at end of file +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +RETRIEVAL_TOP_N= diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 99f86be12ebbb4..f48c074ac8c1cc 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -619,6 +619,11 @@ class DataSetConfig(BaseSettings): PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field( description="Interval in days for message cleanup operations - plan: sandbox", default=30, + ) + + RETRIEVAL_TOP_N: Optional[PositiveInt] = Field( + description="number of retrieval top_n", + default=None ) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 57af05861c1ad0..dbd8fb7cd76073 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,6 +3,7 @@ from flask import Flask, current_app +from configs import DifyConfig from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -105,13 +106,18 @@ def retrieve( exception_message = ";\n".join(exceptions) raise Exception(exception_message) + if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else top_k ) + return all_documents @classmethod @@ -178,7 +184,10 @@ def embedding_search( ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) ) ) else: @@ -220,7 +229,10 @@ def full_text_index_search( ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) ) ) else: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3e2b276c92d66e..a9564976fc67ce 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -287,6 +287,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} + RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-} services: # API service From 18467930efc43f3f802797a50c81dbaee4ddb99e Mon Sep 17 00:00:00 2001 From: ProseGuys <798770222@qq.com> Date: Tue, 26 Nov 2024 17:59:19 +0800 Subject: [PATCH 2/3] style: polish code --- api/configs/feature/__init__.py | 7 ++---- api/core/rag/datasource/retrieval_service.py | 25 ++++++++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f48c074ac8c1cc..ce9f8f2725dce9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -619,13 +619,10 @@ class DataSetConfig(BaseSettings): PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field( description="Interval in days for message cleanup operations - plan: sandbox", default=30, - ) - - RETRIEVAL_TOP_N: Optional[PositiveInt] = Field( - description="number of retrieval top_n", - default=None ) + RETRIEVAL_TOP_N: Optional[PositiveInt] = Field(description="number of retrieval top_n", default=None) + class WorkspaceConfig(BaseSettings): """ diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index dbd8fb7cd76073..759166da4c4be8 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -106,16 +106,15 @@ def retrieve( exception_message = ";\n".join(exceptions) raise Exception(exception_message) - if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else top_k + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or top_k, ) return all_documents @@ -184,10 +183,10 @@ def embedding_search( ) all_documents.extend( data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: @@ -229,10 +228,10 @@ def full_text_index_search( ) all_documents.extend( data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: From a07a5e1cd44328cb5d09299f155518ddaf110788 Mon Sep 17 00:00:00 2001 From: ProseGuys <798770222@qq.com> Date: Wed, 27 Nov 2024 09:29:34 +0800 Subject: [PATCH 3/3] fix: adjust env's params --- api/.env.example | 2 +- api/configs/feature/__init__.py | 2 +- docker/docker-compose.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/.env.example b/api/.env.example index e12c76a229f16a..95c64f48524a5a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -412,4 +412,4 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 CREATE_TIDB_SERVICE_JOB_ENABLED=false -RETRIEVAL_TOP_N= +RETRIEVAL_TOP_N=0 \ No newline at end of file diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ce9f8f2725dce9..2fdf151e038e1e 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -621,7 +621,7 @@ class DataSetConfig(BaseSettings): default=30, ) - RETRIEVAL_TOP_N: Optional[PositiveInt] = Field(description="number of retrieval top_n", default=None) + RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0) class WorkspaceConfig(BaseSettings): diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a9564976fc67ce..aa3c810d3faff0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -287,7 +287,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} - RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-} + RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0} services: # API service