From f9c2aa768927d2a88f292351e40ea5bf2640cf8a Mon Sep 17 00:00:00 2001 From: Cling_o3 <45124798+ProseGuys@users.noreply.github.com> Date: Sat, 30 Nov 2024 11:14:45 +0800 Subject: [PATCH] feat: add retrieval_top_n to config in env (#11132) --- api/.env.example | 4 +++- api/configs/feature/__init__.py | 2 ++ api/core/rag/datasource/retrieval_service.py | 17 ++++++++++++++--- docker/docker-compose.yaml | 1 + 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/api/.env.example b/api/.env.example index 562ed970e06245..40f35cff6c66f6 100644 --- a/api/.env.example +++ b/api/.env.example @@ -411,4 +411,6 @@ POSITION_PROVIDER_EXCLUDES= # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -CREATE_TIDB_SERVICE_JOB_ENABLED=false \ No newline at end of file +CREATE_TIDB_SERVICE_JOB_ENABLED=false + +RETRIEVAL_TOP_N=0 \ No newline at end of file diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 572f91ad2e7dd9..275b16a9131add 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -626,6 +626,8 @@ class DataSetConfig(BaseSettings): default=30, ) + RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0) + class WorkspaceConfig(BaseSettings): """ diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 57af05861c1ad0..759166da4c4be8 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,6 +3,7 @@ from flask import Flask, current_app +from configs import DifyConfig from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -110,8 +111,12 @@ def retrieve( str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) all_documents = data_post_processor.invoke( - query=query, 
documents=all_documents, score_threshold=score_threshold, top_n=top_k + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or top_k, ) + return all_documents @classmethod @@ -178,7 +183,10 @@ def embedding_search( ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: @@ -220,7 +228,10 @@ def full_text_index_search( ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3e2b276c92d66e..aa3c810d3faff0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -287,6 +287,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} + RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0} services: # API service