diff --git a/minds/knowledge_bases/knowledge_bases.py b/minds/knowledge_bases/knowledge_bases.py
index 58b7f96..337a5b4 100644
--- a/minds/knowledge_bases/knowledge_bases.py
+++ b/minds/knowledge_bases/knowledge_bases.py
@@ -2,6 +2,7 @@
 
 from pydantic import BaseModel
 
+from minds.knowledge_bases.preprocessing import PreprocessingConfig
 from minds.rest_api import RestAPI
 
 
@@ -25,6 +26,8 @@ class KnowledgeBaseConfig(BaseModel):
     description: str
     vector_store_config: Optional[VectorStoreConfig] = None
     embedding_config: Optional[EmbeddingConfig] = None
+    # Params to apply to retrieval pipeline.
+    params: Optional[Dict] = None
 
 
 class KnowledgeBaseDocument(BaseModel):
@@ -39,7 +42,7 @@ def __init__(self, name, api: RestAPI):
         self.name = name
         self.api = api
 
-    def insert_from_select(self, query: str):
+    def insert_from_select(self, query: str, preprocessing_config: PreprocessingConfig = None):
         '''
         Inserts select content of a connected datasource into this knowledge base
 
@@ -48,9 +51,11 @@
         update_request = {
             'query': query
         }
+        if preprocessing_config is not None:
+            update_request['preprocessing'] = preprocessing_config.model_dump()
         _ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
 
-    def insert_documents(self, documents: List[KnowledgeBaseDocument]):
+    def insert_documents(self, documents: List[KnowledgeBaseDocument], preprocessing_config: PreprocessingConfig = None):
         '''
         Inserts documents directly into this knowledge base
 
@@ -59,9 +64,11 @@
         update_request = {
             'rows': [d.model_dump() for d in documents]
         }
+        if preprocessing_config is not None:
+            update_request['preprocessing'] = preprocessing_config.model_dump()
         _ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
 
-    def insert_urls(self, urls: List[str]):
+    def insert_urls(self, urls: List[str], preprocessing_config: PreprocessingConfig = None):
         '''
         Crawls URLs & inserts the retrieved webpages into this knowledge base
 
@@ -70,9 +77,11 @@
         update_request = {
             'urls': urls
         }
+        if preprocessing_config is not None:
+            update_request['preprocessing'] = preprocessing_config.model_dump()
         _ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
 
-    def insert_files(self, files: List[str]):
+    def insert_files(self, files: List[str], preprocessing_config: PreprocessingConfig = None):
         '''
         Inserts files that have already been uploaded to MindsDB into this knowledge base
 
@@ -81,6 +90,8 @@
         update_request = {
             'files': files
         }
+        if preprocessing_config is not None:
+            update_request['preprocessing'] = preprocessing_config.model_dump()
         _ = self.api.put(f'/knowledge_bases/{self.name}', data=update_request)
 
 
@@ -117,6 +128,8 @@ def create(self, config: KnowledgeBaseConfig) -> KnowledgeBase:
             if config.embedding_config.params is not None:
                 embedding_data.update(config.embedding_config.params)
             create_request['embedding_model'] = embedding_data
+        if config.params is not None:
+            create_request['params'] = config.params
 
         _ = self.api.post('/knowledge_bases', data=create_request)
         return self.get(config.name)
diff --git a/minds/knowledge_bases/preprocessing.py b/minds/knowledge_bases/preprocessing.py
new file mode 100644
index 0000000..7ea4716
--- /dev/null
+++ b/minds/knowledge_bases/preprocessing.py
@@ -0,0 +1,78 @@
+from typing import Any, Dict, List, Literal, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+DEFAULT_LLM_MODEL = 'gpt-4o'
+DEFAULT_LLM_MODEL_PROVIDER = 'openai'
+
+
+class TextChunkingConfig(BaseModel):
+    '''Configuration for chunking text content before it is inserted into a knowledge base'''
+    separators: List[str] = Field(
+        default=['\n\n', '\n', ' ', ''],
+        description='List of separators to use for splitting text, in order of priority'
+    )
+    chunk_size: int = Field(
+        default=1000,
+        description='The target size of each text chunk',
+        gt=0
+    )
+    chunk_overlap: int = Field(
+        default=200,
+        description='The number of characters to overlap between chunks',
+        ge=0
+    )
+
+
+class LLMConfig(BaseModel):
+    model_name: str = Field(default=DEFAULT_LLM_MODEL, description='LLM model to use for context generation')
+    provider: str = Field(default=DEFAULT_LLM_MODEL_PROVIDER, description='LLM model provider to use for context generation')
+    params: Dict[str, Any] = Field(default={}, description='Additional parameters to pass in when initializing the LLM')
+
+
+class ContextualConfig(BaseModel):
+    '''Configuration specific to contextual preprocessing'''
+    llm_config: LLMConfig = Field(
+        default=LLMConfig(),
+        description='LLM configuration to use for context generation'
+    )
+    context_template: Optional[str] = Field(
+        default=None,
+        description='Custom template for context generation'
+    )
+    chunk_size: int = Field(
+        default=1000,
+        description='The target size of each text chunk',
+        gt=0
+    )
+    chunk_overlap: int = Field(
+        default=200,
+        description='The number of characters to overlap between chunks',
+        ge=0
+    )
+
+
+class PreprocessingConfig(BaseModel):
+    '''Complete preprocessing configuration'''
+    type: Literal['contextual', 'text_chunking'] = Field(
+        default='text_chunking',
+        description='Type of preprocessing to apply'
+    )
+    contextual_config: Optional[ContextualConfig] = Field(
+        default=None,
+        description='Configuration for contextual preprocessing'
+    )
+    text_chunking_config: Optional[TextChunkingConfig] = Field(
+        default=None,
+        description='Configuration for text chunking preprocessing'
+    )
+
+    @model_validator(mode='after')
+    def validate_config_presence(self) -> 'PreprocessingConfig':
+        '''Ensure the appropriate config is present for the chosen type'''
+        if self.type == 'contextual' and not self.contextual_config:
+            self.contextual_config = ContextualConfig()
+        if self.type == 'text_chunking' and not self.text_chunking_config:
+            self.text_chunking_config = TextChunkingConfig()
+        return self
diff --git a/tests/unit/test_unit.py b/tests/unit/test_unit.py
index 4580e4b..daeac79 100644
--- a/tests/unit/test_unit.py
+++ b/tests/unit/test_unit.py
@@ -130,7 +130,10 @@ def test_create_knowledge_bases(self, mock_post, mock_get):
             name='test_kb',
             description='Test knowledge base',
             vector_store_config=test_vector_store_config,
-            embedding_config=test_embedding_config
+            embedding_config=test_embedding_config,
+            params={
+                'k1': 'v1'
+            }
         )
 
         response_mock(mock_get, test_knowledge_base_config.model_dump())
@@ -152,6 +155,9 @@
                 'provider': test_embedding_config.provider,
                 'name': test_embedding_config.model,
                 'k1': 'v1'
+            },
+            'params': {
+                'k1': 'v1'
             }
         }
 
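Taken together, these changes let callers attach free-form params when creating a knowledge base and opt in to a preprocessing strategy per insert call. A minimal usage sketch follows, assuming the client exposes this manager as client.knowledge_bases and that Client takes an API key positionally; the key, names, URL, and tuning values are illustrative, not part of this diff:

from minds.client import Client
from minds.knowledge_bases.knowledge_bases import KnowledgeBaseConfig
from minds.knowledge_bases.preprocessing import ContextualConfig, LLMConfig, PreprocessingConfig

client = Client('YOUR_API_KEY')  # placeholder key

# 'params' is forwarded as-is into the POST /knowledge_bases payload
# (see the new branch in KnowledgeBases.create above).
kb = client.knowledge_bases.create(KnowledgeBaseConfig(
    name='demo_kb',
    description='Demo knowledge base',
    params={'k1': 'v1'},
))

# Text chunking is the default type; selecting 'contextual' opts in to
# LLM-generated context. The model_validator backfills a default
# ContextualConfig when none is given, so the explicit one here only
# overrides the chunking values.
preprocessing = PreprocessingConfig(
    type='contextual',
    contextual_config=ContextualConfig(
        llm_config=LLMConfig(model_name='gpt-4o', provider='openai'),
        chunk_size=500,
        chunk_overlap=50,
    ),
)
kb.insert_urls(['https://docs.mindsdb.com'], preprocessing_config=preprocessing)

Omitting preprocessing_config leaves the PUT body unchanged, since every insert method only adds the 'preprocessing' key when a config is passed, so existing callers keep their current behavior.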