Add a RETRIEVAL_TOP_N setting and pass it as top_n to DataPostProcessor.invoke.

  * api/.env.example, docker/docker-compose.yaml: expose RETRIEVAL_TOP_N, default 0.
  * api/configs/feature/__init__.py: declare DataSetConfig.RETRIEVAL_TOP_N (int, default 0).
  * api/core/rag/datasource/retrieval_service.py: use RETRIEVAL_TOP_N as top_n when it is
    non-zero; otherwise keep the previous value (top_k or len(documents)).

The value is read from the settings instance (dify_config), not from the DifyConfig class:
pydantic fields are not class attributes. 0 is used as the "unset" value so that the env
defaults are always parseable as an integer.

NOTE(review): assumes the configs package exports the dify_config instance -- confirm.
NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed
copy of this patch; verify context-line whitespace against the target tree before applying.

diff --git a/api/.env.example b/api/.env.example
index f8a281256380c2..e12c76a229f16a 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -410,4 +410,6 @@ POSITION_PROVIDER_EXCLUDES=
 # Reset password token expiry minutes
 RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
 
-CREATE_TIDB_SERVICE_JOB_ENABLED=false
\ No newline at end of file
+CREATE_TIDB_SERVICE_JOB_ENABLED=false
+
+RETRIEVAL_TOP_N=0
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 99f86be12ebbb4..f48c074ac8c1cc 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -619,6 +619,11 @@ class DataSetConfig(BaseSettings):
     PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field(
         description="Interval in days for message cleanup operations - plan: sandbox",
         default=30,
+    )
+
+    RETRIEVAL_TOP_N: int = Field(
+        description="number of retrieval top_n (0 disables the override)",
+        default=0,
     )
 
 
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 57af05861c1ad0..dbd8fb7cd76073 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -3,6 +3,7 @@
 
 from flask import Flask, current_app
 
+from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -105,13 +106,18 @@ def retrieve(
                 exception_message = ";\n".join(exceptions)
                 raise Exception(exception_message)
 
+
         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
             data_post_processor = DataPostProcessor(
                 str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
             )
             all_documents = data_post_processor.invoke(
-                query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
+                query=query,
+                documents=all_documents,
+                score_threshold=score_threshold,
+                top_n=dify_config.RETRIEVAL_TOP_N or top_k,
             )
+
         return all_documents
 
     @classmethod
@@ -178,7 +184,10 @@ def embedding_search(
                         )
                         all_documents.extend(
                             data_post_processor.invoke(
-                                query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
+                                query=query,
+                                documents=documents,
+                                score_threshold=score_threshold,
+                                top_n=dify_config.RETRIEVAL_TOP_N or len(documents),
                             )
                         )
                     else:
@@ -220,7 +229,10 @@ def full_text_index_search(
                         )
                         all_documents.extend(
                             data_post_processor.invoke(
-                                query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
+                                query=query,
+                                documents=documents,
+                                score_threshold=score_threshold,
+                                top_n=dify_config.RETRIEVAL_TOP_N or len(documents),
                             )
                         )
                     else:
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 3e2b276c92d66e..a9564976fc67ce 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -287,6 +287,7 @@ x-shared-env: &shared-api-worker-env
   OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
   OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
   CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
+  RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0}
 
 services:
   # API service