Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Jun 10, 2024
1 parent af60d5c commit 277aa09
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.retrieval_service import RetrievalMethod
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
Expand Down
17 changes: 1 addition & 16 deletions api/core/rag/datasource/retrieval_service.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
import threading
from enum import Enum
from typing import Optional

from flask import Flask, current_app

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset


class RetrievalMethod(str, Enum):
SEMANTIC_SEARCH = 'semantic_search'
FULL_TEXT_SEARCH = 'full_text_search'
HYBRID_SEARCH = 'hybrid_search'

@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}

@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}


default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'reranking_enable': False,
Expand Down
15 changes: 15 additions & 0 deletions api/core/rag/datasource/retrival_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import Enum


class RetrievalMethod(str, Enum):
SEMANTIC_SEARCH = 'semantic_search'
FULL_TEXT_SEARCH = 'full_text_search'
HYBRID_SEARCH = 'hybrid_search'

@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}

@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}
3 changes: 2 additions & 1 deletion api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalMethod, RetrievalService
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.models.document import Document
from core.rag.rerank.rerank import RerankRunner
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalMethod, RetrievalService
from core.rag.datasource.retrival_methods import RetrievalMethod, RetrievalService
from core.rag.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

from pydantic import BaseModel, Field

from core.rag.datasource.retrieval_service import RetrievalMethod, RetrievalService
from core.rag.datasource.retrival_methods import RetrievalMethod, RetrievalService
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalMethod
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
Expand Down
2 changes: 1 addition & 1 deletion api/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from core.rag.datasource.retrieval_service import RetrievalMethod
from core.rag.datasource.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
Expand Down
2 changes: 1 addition & 1 deletion api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalMethod
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.models.document import Document as RAGDocument
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
Expand Down
3 changes: 2 additions & 1 deletion api/services/hit_testing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.retrieval_service import RetrievalMethod, RetrievalService
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrival_methods import RetrievalMethod
from core.rag.models.document import Document
from extensions.ext_database import db
from models.account import Account
Expand Down

0 comments on commit 277aa09

Please sign in to comment.