
Commit 37808a3
Merge branch 'feat/support-parent-child-chunk' into deploy/dev
JohnJyong committed Dec 24, 2024
2 parents d4b0eb4 + 71950ee
Showing 5 changed files with 53 additions and 47 deletions.
29 changes: 15 additions & 14 deletions api/core/indexing_runner.py
@@ -274,6 +274,7 @@ def indexing_estimate(
                     model_type=ModelType.TEXT_EMBEDDING,
                 )
         preview_texts = []
+
         total_segments = 0
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -289,14 +290,21 @@ def indexing_estimate(
                 process_rule=processing_rule.to_dict(),
                 tenant_id=current_user.current_tenant_id,
                 doc_language=doc_language,
+                preview=True
             )
             total_segments += len(documents)
             for document in documents:
                 if len(preview_texts) < 10:
-                    preview_detail = PreviewDetail(content=document.page_content)
-                    if document.children:
-                        preview_detail.child_chunks = [child.page_content for child in document.children]
-                    preview_texts.append(preview_detail)
+                    if doc_form and doc_form == "qa_model":
+                        preview_detail = QAPreviewDetail(
+                            question=document.page_content, answer=document.metadata.get("answer")
+                        )
+                        preview_texts.append(preview_detail)
+                    else:
+                        preview_detail = PreviewDetail(content=document.page_content)
+                        if document.children:
+                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                        preview_texts.append(preview_detail)

                 # delete image files and related db records
                 image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@@ -312,16 +320,9 @@ def indexing_estimate(
                         db.session.delete(image_file)

         if doc_form and doc_form == "qa_model":
-            if len(preview_texts) > 0:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(
-                    current_user.current_tenant_id, preview_texts[0].content, doc_language
-                )
-                document_qa_list = self.format_split_text(response)
-
-                return IndexingEstimate(
-                    total_segments=total_segments * 20, qa_preview=document_qa_list, preview=preview_texts
-                )
+            return IndexingEstimate(
+                total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]
+            )
         return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

     def _extract(
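Taken together, the indexing_runner changes move QA preview generation out of the estimate step: transform() is now invoked with preview=True and hands back documents that already carry the question in page_content and the answer in metadata, so indexing_estimate only has to pick the right preview shape (the QA estimate still scales total_segments by 20). A minimal sketch of the new branching, using simplified dataclass stand-ins for the real Document, PreviewDetail, and QAPreviewDetail models:

# Sketch only: Document, PreviewDetail, and QAPreviewDetail are simplified
# stand-ins for the Pydantic/ORM models used in the real codebase.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Document:
    page_content: str
    metadata: dict = field(default_factory=dict)
    children: list["Document"] = field(default_factory=list)


@dataclass
class PreviewDetail:
    content: str
    child_chunks: Optional[list[str]] = None


@dataclass
class QAPreviewDetail:
    question: str
    answer: str


def build_previews(documents: list[Document], doc_form: str, limit: int = 10) -> list:
    previews: list = []
    for document in documents:
        if len(previews) >= limit:
            break
        if doc_form == "qa_model":
            # the QA transform already stored the pair on the document
            previews.append(
                QAPreviewDetail(question=document.page_content, answer=document.metadata.get("answer"))
            )
        else:
            detail = PreviewDetail(content=document.page_content)
            if document.children:  # parent-child chunk mode
                detail.child_chunks = [child.page_content for child in document.children]
            previews.append(detail)
    return previews


docs = [Document("What is parent-child chunking?", {"answer": "Splitting parents into child chunks."})]
print(build_previews(docs, "qa_model"))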
6 changes: 1 addition & 5 deletions api/core/rag/extractor/extract_processor.py
@@ -131,11 +131,7 @@ def extract(
                     extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
                 else:
                     # txt
-                    extractor = (
-                        UnstructuredTextExtractor(file_path, unstructured_api_url)
-                        if is_automatic
-                        else TextExtractor(file_path, autodetect_encoding=True)
-                    )
+                    extractor = TextExtractor(file_path, autodetect_encoding=True)
             else:
                 if file_extension in {".xlsx", ".xls"}:
                     extractor = ExcelExtractor(file_path)
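With this hunk, plain .txt files no longer go through the Unstructured API even when is_automatic is set; they are always parsed locally with encoding autodetection. A rough usage sketch — the constructor arguments come from the diff, but the module path and the sample file are assumptions:

# Assumed import path (mirroring api/core/rag/extractor/text_extractor.py);
# the file path is hypothetical.
from core.rag.extractor.text_extractor import TextExtractor

extractor = TextExtractor("/tmp/notes.txt", autodetect_encoding=True)
documents = extractor.extract()  # extractors return a list of Document chunks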
44 changes: 26 additions & 18 deletions api/core/rag/index_processor/processor/qa_index_processor.py
@@ -35,6 +35,7 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         return text_docs

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
@@ -67,24 +68,31 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
                 document_node.page_content = remove_leading_symbols(page_content)
                 split_documents.append(document_node)
         all_documents.extend(split_documents)
-        for i in range(0, len(all_documents), 10):
-            threads = []
-            sub_documents = all_documents[i : i + 10]
-            for doc in sub_documents:
-                document_format_thread = threading.Thread(
-                    target=self._format_qa_document,
-                    kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
-                        "document_node": doc,
-                        "all_qa_documents": all_qa_documents,
-                        "document_language": kwargs.get("doc_language", "English"),
-                    },
-                )
-                threads.append(document_format_thread)
-                document_format_thread.start()
-            for thread in threads:
-                thread.join()
+        if preview:
+            self._format_qa_document(current_app._get_current_object(),
+                                     kwargs.get("tenant_id"),
+                                     all_documents[0],
+                                     all_qa_documents,
+                                     kwargs.get("doc_language", "English"))
+        else:
+            for i in range(0, len(all_documents), 10):
+                threads = []
+                sub_documents = all_documents[i : i + 10]
+                for doc in sub_documents:
+                    document_format_thread = threading.Thread(
+                        target=self._format_qa_document,
+                        kwargs={
+                            "flask_app": current_app._get_current_object(),
+                            "tenant_id": kwargs.get("tenant_id"),
+                            "document_node": doc,
+                            "all_qa_documents": all_qa_documents,
+                            "document_language": kwargs.get("doc_language", "English"),
+                        },
+                    )
+                    threads.append(document_format_thread)
+                    document_format_thread.start()
+                for thread in threads:
+                    thread.join()
         return all_qa_documents

     def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
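In preview mode, transform() now formats a single QA document synchronously from the first chunk, which keeps the estimate endpoint from spinning up threads for work that would be thrown away. The non-preview path keeps the existing fan-out: batches of ten chunks, one thread per chunk, joined before the next batch starts. A self-contained sketch of that batching pattern, with process_one standing in for _format_qa_document:

import threading


def process_in_batches(items, process_one, batch_size=10):
    """Run process_one over items, batch_size threads at a time.

    Mirrors the transform() fan-out: start one thread per item in the
    current batch, then join them all before starting the next batch.
    """
    for i in range(0, len(items), batch_size):
        threads = []
        for item in items[i : i + batch_size]:
            thread = threading.Thread(target=process_one, args=(item,))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()


# Results are collected the same way all_qa_documents is shared across threads.
results = []
process_in_batches(range(25), lambda n: results.append(n * 2))
assert len(results) == 25  # list.append is atomic in CPython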
16 changes: 8 additions & 8 deletions api/services/dataset_service.py
@@ -886,7 +886,7 @@ def save_document_with_dataset_id(
                        exist_page_ids.append(data_source_info["notion_page_id"])
                        exist_document[data_source_info["notion_page_id"]] = document.id
                for notion_info in notion_info_list:
-                    workspace_id = notion_info["workspace_id"]
+                    workspace_id = notion_info.workspace_id
                    data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -897,13 +897,13 @@ def save_document_with_dataset_id(
                    ).first()
                    if not data_source_binding:
                        raise ValueError("Data source binding not found.")
-                    for page in notion_info["pages"]:
-                        if page["page_id"] not in exist_page_ids:
+                    for page in notion_info.pages:
+                        if page.page_id not in exist_page_ids:
                            data_source_info = {
                                "notion_workspace_id": workspace_id,
-                                "notion_page_id": page["page_id"],
-                                "notion_page_icon": page["page_icon"],
-                                "type": page["type"],
+                                "notion_page_id": page.page_id,
+                                "notion_page_icon": page.page_icon,
+                                "type": page.type,
                            }
                            document = DocumentService.build_document(
                                dataset,
@@ -915,7 +915,7 @@ def save_document_with_dataset_id(
                                created_from,
                                position,
                                account,
-                                page["page_name"],
+                                page.page_name,
                                batch,
                            )
                            db.session.add(document)
@@ -924,7 +924,7 @@ def save_document_with_dataset_id(
                            documents.append(document)
                            position += 1
                        else:
-                            exist_document.pop(page["page_id"])
+                            exist_document.pop(page.page_id)
                # delete not selected documents
                if len(exist_document) > 0:
                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
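The subscript-to-attribute rewrites follow from notion_info_list now holding validated Pydantic models instead of raw dicts. A minimal sketch of the shapes this service code now expects — the NotionPage fields match the entity diff below, while NotionInfo's workspace_id and pages fields are inferred from how they are read here:

from typing import Optional

from pydantic import BaseModel


class NotionPage(BaseModel):
    page_id: str
    page_name: str
    page_icon: Optional[str]
    type: str


class NotionInfo(BaseModel):  # inferred fields; the real model may carry more
    workspace_id: str
    pages: list[NotionPage]


info = NotionInfo(
    workspace_id="ws-123",  # hypothetical values
    pages=[NotionPage(page_id="p1", page_name="Spec", page_icon=None, type="page")],
)
assert info.pages[0].page_id == "p1"  # attribute access, not info["pages"][0]["page_id"]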
5 changes: 3 additions & 2 deletions
@@ -19,7 +19,8 @@ class ParentMode(str, Enum):
 class NotionPage(BaseModel):
     page_id: str
     page_name: str
-    page_icon: str
+    page_icon: Optional[str]
+    type: str


 class NotionInfo(BaseModel):
@@ -31,7 +32,7 @@ class WebsiteInfo(BaseModel):
     provider: str
     job_id: str
     urls: list[str]
-    only_main_content: bool
+    only_main_content: bool = True


 class FileInfo(BaseModel):
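These model tweaks track the service change above: page_icon becomes Optional so icon-less Notion pages validate, the new type field carries the page kind that save_document_with_dataset_id now reads, and only_main_content gains a default so callers may omit it. A quick check of the new default (values are illustrative):

from pydantic import BaseModel


class WebsiteInfo(BaseModel):
    provider: str
    job_id: str
    urls: list[str]
    only_main_content: bool = True


# Illustrative values; only_main_content may now be omitted by callers.
info = WebsiteInfo(provider="firecrawl", job_id="job-1", urls=["https://example.com"])
assert info.only_main_content is True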
