Merge branch 'feat/support-parent-child-chunk' into deploy/dev
JohnJyong committed Dec 6, 2024
2 parents e90c41f + 4f3a976 commit 8396eeb
Showing 6 changed files with 69 additions and 23 deletions.
2 changes: 1 addition & 1 deletion api/controllers/console/datasets/datasets.py
@@ -759,4 +759,4 @@ def get(self, dataset_id):
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
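
The removed and added lines above are identical, which suggests the only change in this file is a trailing newline. The DatasetAutoDisableLogApi route on that line is the read side of the DatasetAutoDisableLog records written by clean_unused_datasets_task later in this commit. A minimal sketch of calling it, assuming a console API host and session token (neither is part of this diff):

```python
# Hedged sketch: host, token, and dataset id are placeholders; the response
# shape is not shown in this diff.
import requests

BASE_URL = "https://example.dify.host/console/api"   # hypothetical console API host
dataset_id = "9f8c1c2e-1111-2222-3333-444444444444"  # hypothetical dataset id

resp = requests.get(
    f"{BASE_URL}/datasets/{dataset_id}/auto-disable-logs",
    headers={"Authorization": "Bearer <console-session-token>"},
)
resp.raise_for_status()
print(resp.json())
```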
19 changes: 19 additions & 0 deletions api/controllers/console/datasets/datasets_document.py
@@ -278,6 +278,25 @@ def post(self, dataset_id):

return {"documents": documents, "batch": batch}

@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)

try:
document_ids = request.args.getlist("document_id")
DatasetService.delete_documents(dataset, document_ids)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

return {"result": "success"}, 204


class DatasetInitApi(Resource):
@setup_required
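
The new delete handler reads repeated document_id query parameters via request.args.getlist and hands them to DatasetService.delete_documents, returning 204 on success or raising DocumentIndexingError if a document is still being indexed. A hedged client-side sketch, assuming the resource remains mounted at /datasets/<dataset_id>/documents (the registration itself is not part of this hunk):

```python
# Hedged sketch only: endpoint path, host, and ids below are assumptions for illustration.
import requests

BASE_URL = "https://example.dify.host/console/api"   # hypothetical console API host
dataset_id = "9f8c1c2e-1111-2222-3333-444444444444"  # hypothetical dataset id
document_ids = ["doc-id-1", "doc-id-2"]              # hypothetical document ids

resp = requests.delete(
    f"{BASE_URL}/datasets/{dataset_id}/documents",
    # request.args.getlist("document_id") reads repeated query parameters
    params=[("document_id", doc_id) for doc_id in document_ids],
    headers={"Authorization": "Bearer <console-session-token>"},
)
resp.raise_for_status()  # the handler returns 204 on success
```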
5 changes: 3 additions & 2 deletions api/controllers/console/datasets/datasets_segments.py
@@ -25,7 +25,6 @@
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
@@ -639,7 +638,9 @@ def patch(self, dataset_id, document_id, segment_id, child_chunk_id):


api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
6 changes: 3 additions & 3 deletions api/core/rag/embedding/retrieval.py
@@ -7,16 +7,16 @@

class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""

id: str
content: str
score: float


class RetrievalSegments(BaseModel):
"""Retrieval segments."""
model_config = {
"arbitrary_types_allowed": True
}

model_config = {"arbitrary_types_allowed": True}
segment: DocumentSegment
child_chunks: Optional[list[RetrievalChildChunk]] = None
score: Optional[float] = None
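
The functional point of this hunk is unchanged: segment is typed as DocumentSegment, a SQLAlchemy model rather than a Pydantic model, so Pydantic v2 only accepts the field when arbitrary_types_allowed is set; the edit just collapses model_config onto one line. A self-contained sketch of the behavior, using an illustrative stand-in for the real DocumentSegment:

```python
# Minimal sketch: Pydantic v2 rejects fields typed with plain (non-Pydantic)
# classes unless the model opts in via arbitrary_types_allowed, in which case
# the value is checked with isinstance. The DocumentSegment below is a stand-in,
# not Dify's SQLAlchemy model.
from typing import Optional
from pydantic import BaseModel


class DocumentSegment:  # stand-in for models.dataset.DocumentSegment
    def __init__(self, content: str):
        self.content = content


class RetrievalChildChunk(BaseModel):
    id: str
    content: str
    score: float


class RetrievalSegments(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    segment: DocumentSegment  # arbitrary (non-Pydantic) type, allowed by the config
    child_chunks: Optional[list[RetrievalChildChunk]] = None
    score: Optional[float] = None


hit = RetrievalSegments(
    segment=DocumentSegment("parent chunk text"),
    child_chunks=[RetrievalChildChunk(id="c1", content="child chunk text", score=0.87)],
    score=0.91,
)
```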
28 changes: 18 additions & 10 deletions api/schedule/clean_unused_datasets_task.py
@@ -77,11 +77,15 @@ def clean_unused_datasets_task():
if not dataset_query or len(dataset_query) == 0:
try:
# add auto disable log
documents = db.session.query(Document).filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
).all()
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
@@ -167,11 +171,15 @@ def clean_unused_datasets_task():
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = db.session.query(Document).filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
).all()
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
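
Both hunks in this task are black-style reformatting of the same query into a chained call; behavior is unchanged. The == True / == False comparisons are kept because they build SQL boolean expressions rather than Python identity checks. A standalone sketch of the pattern with illustrative models (not Dify's):

```python
# Illustrative sketch: "Model.column == True" produces a SQLAlchemy
# BinaryExpression compiled to "column = true", which is why the
# linter-unfriendly "== True" is intentional in the query above.
from sqlalchemy import Boolean, Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Doc(Base):  # hypothetical stand-in for models.dataset.Document
    __tablename__ = "docs"
    id = Column(String, primary_key=True)
    dataset_id = Column(String)
    enabled = Column(Boolean, default=True)
    archived = Column(Boolean, default=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        Doc(id="a", dataset_id="ds1", enabled=True, archived=False),
        Doc(id="b", dataset_id="ds1", enabled=False, archived=False),
    ])
    session.commit()

    active_docs = (
        session.query(Doc)
        .filter(
            Doc.dataset_id == "ds1",
            Doc.enabled == True,  # noqa: E712 -- SQL expression, not Python identity
            Doc.archived == False,  # noqa: E712
        )
        .all()
    )
    print([d.id for d in active_docs])  # -> ["a"]
```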
32 changes: 25 additions & 7 deletions api/services/dataset_service.py
@@ -53,6 +53,7 @@
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@@ -405,7 +406,7 @@ def get_related_apps(dataset_id: str):
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)

@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
# get recent 30 days auto disable logs
@@ -604,6 +605,20 @@ def delete_document(document):
db.session.delete(document)
db.session.commit()

@staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]):
documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
if document.data_source_type == "upload_file"
]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)

for document in documents:
db.session.delete(document)
db.session.commit()

@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
@@ -1903,12 +1918,15 @@ def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset):
db.session.commit()

@classmethod
def get_child_chunks(cls, segment_id: str, document_id: str, dataset_id: str,
page: int, limit: int, keyword: Optional[str] = None):
query = ChildChunk.query.filter_by(tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id)
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = ChildChunk.query.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
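
Read together with the controller change, the new bulk path is: delete_documents collects upload_file ids from the selected documents, queues batch_clean_document_task to clean the index and stored files asynchronously, then deletes the Document rows and commits. A hedged usage sketch from inside the api package, with hypothetical ids and without the indexing-error handling the controller adds:

```python
# Sketch only: ids are hypothetical; DocumentIndexingError handling (see the
# controller hunk above) is omitted.
from services.dataset_service import DatasetService

dataset = DatasetService.get_dataset("9f8c1c2e-1111-2222-3333-444444444444")
if dataset is None:
    raise ValueError("Dataset not found.")

# Internally queues batch_clean_document_task.delay(document_ids, dataset.id,
# dataset.doc_form, file_ids) and then removes the Document rows in one commit.
DatasetService.delete_documents(dataset, ["doc-id-1", "doc-id-2"])
```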
