From 18467930efc43f3f802797a50c81dbaee4ddb99e Mon Sep 17 00:00:00 2001 From: ProseGuys <798770222@qq.com> Date: Tue, 26 Nov 2024 17:59:19 +0800 Subject: [PATCH] style: polish code --- api/configs/feature/__init__.py | 7 ++---- api/core/rag/datasource/retrieval_service.py | 25 ++++++++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f48c074ac8c1cc..ce9f8f2725dce9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -619,13 +619,10 @@ class DataSetConfig(BaseSettings): PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field( description="Interval in days for message cleanup operations - plan: sandbox", default=30, - ) - - RETRIEVAL_TOP_N: Optional[PositiveInt] = Field( - description="number of retrieval top_n", - default=None ) + RETRIEVAL_TOP_N: Optional[PositiveInt] = Field(description="number of retrieval top_n", default=None) + class WorkspaceConfig(BaseSettings): """ diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index dbd8fb7cd76073..759166da4c4be8 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -106,16 +106,15 @@ def retrieve( exception_message = ";\n".join(exceptions) raise Exception(exception_message) - if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else top_k + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or top_k, ) return all_documents @@ -184,10 +183,10 @@ def embedding_search( ) all_documents.extend( data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: @@ -229,10 +228,10 @@ def full_text_index_search( ) all_documents.extend( data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=DifyConfig.RETRIEVAL_TOP_N if DifyConfig.RETRIEVAL_TOP_N else len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents), ) ) else: