Skip to content

Commit

Permalink
fix wrong using of RetrievalMethod Enum (#6345)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Jul 16, 2024
1 parent ed9e692 commit 0de224b
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 24 deletions.
16 changes: 8 additions & 8 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,15 +545,15 @@ def get(self):
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
Expand All @@ -569,15 +569,15 @@ def get(self, vector_type):
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
Expand Down
8 changes: 4 additions & 4 deletions api/core/rag/datasource/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from models.dataset import Dataset

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down Expand Up @@ -86,7 +86,7 @@ def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
exception_message = ';\n'.join(exceptions)
raise Exception(exception_message)

if retrival_method == RetrievalMethod.HYBRID_SEARCH:
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
Expand Down Expand Up @@ -142,7 +142,7 @@ def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
)

if documents:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
Expand Down Expand Up @@ -174,7 +174,7 @@ def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k=top_k
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
Expand Down
4 changes: 2 additions & 2 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from models.dataset import Document as DatasetDocument

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down Expand Up @@ -464,7 +464,7 @@ def to_dataset_retriever_tool(self, tenant_id: str,
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
6 changes: 3 additions & 3 deletions api/core/rag/retrieval/retrival_methods.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from enum import Enum


class RetrievalMethod(str, Enum):
class RetrievalMethod(Enum):
SEMANTIC_SEARCH = 'semantic_search'
FULL_TEXT_SEARCH = 'full_text_search'
HYBRID_SEARCH = 'hybrid_search'

@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}

@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from models.workflow import WorkflowNodeExecutionStatus

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
2 changes: 1 addition & 1 deletion api/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def doc_form(self):
@property
def retrieval_model_dict(self):
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
4 changes: 2 additions & 2 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def save_document_with_dataset_id(
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down Expand Up @@ -1059,7 +1059,7 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun
retrieval_model = document_data['retrieval_model']
else:
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down
2 changes: 1 addition & 1 deletion api/services/hit_testing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from models.dataset import Dataset, DatasetQuery, DocumentSegment

default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
Expand Down

0 comments on commit 0de224b

Please sign in to comment.