Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add support to update record fields #5685

Merged
merged 18 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/argilla.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
build:
services:
argilla-server:
image: argilladev/argilla-hf-spaces:develop
image: argilladev/argilla-hf-spaces:pr-5685
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
ports:
- 6900:6900
env:
Expand Down
5 changes: 5 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Added support to update record fields in `PATCH /api/v1/records/:record_id` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))
- Added support to update record fields in `PUT /api/v1/datasets/:dataset_id/records/bulk` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))

### Changed

- Changed default python version to 3.13. ([#5649](https://github.com/argilla-io/argilla/pull/5649))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def delete_dataset_records(
if num_records > DELETE_DATASET_RECORDS_LIMIT:
raise UnprocessableEntityError(f"Cannot delete more than {DELETE_DATASET_RECORDS_LIMIT} records at once")

await datasets.delete_records(db, search_engine, dataset, record_ids)
await records.delete_records(db, search_engine, dataset, record_ids)


@router.post(
Expand Down
15 changes: 10 additions & 5 deletions argilla-server/src/argilla_server/api/handlers/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from argilla_server.api.schemas.v1.responses import Response, ResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions
from argilla_server.contexts import datasets
from argilla_server.contexts import datasets, records
from argilla_server.database import get_async_db
from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError
from argilla_server.models import Dataset, Question, Record, Suggestion, User
Expand Down Expand Up @@ -74,16 +74,21 @@ async def update_record(
db,
record_id,
options=[
selectinload(Record.dataset).selectinload(Dataset.questions),
selectinload(Record.dataset).selectinload(Dataset.metadata_properties),
selectinload(Record.dataset).options(
selectinload(Dataset.questions),
selectinload(Dataset.metadata_properties),
selectinload(Dataset.vectors_settings),
selectinload(Dataset.fields),
),
selectinload(Record.suggestions),
selectinload(Record.responses),
selectinload(Record.vectors),
],
)

await authorize(current_user, RecordPolicy.update(record))

return await datasets.update_record(db, search_engine, record, record_update)
return await records.update_record(db, search_engine, record, record_update)


@router.post("/records/{record_id}/responses", status_code=status.HTTP_201_CREATED, response_model=Response)
Expand Down Expand Up @@ -233,4 +238,4 @@ async def delete_record(

await authorize(current_user, RecordPolicy.delete(record))

return await datasets.delete_record(db, search_engine, record)
return await records.delete_record(db, search_engine, record)
26 changes: 10 additions & 16 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from uuid import UUID


from argilla_server.api.schemas.v1.chat import ChatFieldValue
from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName
Expand Down Expand Up @@ -151,17 +150,12 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict


class RecordUpdate(UpdateSchema):
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")
fields: Optional[Dict[str, FieldValueCreate]] = None
metadata: Optional[Dict[str, Any]] = None
suggestions: Optional[List[SuggestionCreate]] = None
vectors: Optional[Dict[str, List[float]]]

@property
def metadata(self) -> Optional[Dict[str, Any]]:
# Align with the RecordCreate model. Both should have the same name for the metadata field.
# TODO(@frascuchon): This will be properly adapted once the bulk records refactor is completed.
return self.metadata_

@validator("metadata_")
@validator("metadata")
@classmethod
def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if metadata is None:
Expand All @@ -173,15 +167,20 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict

return {k: v for k, v in metadata.items() if v == v} # By definition, NaN != NaN

def is_set(self, attribute: str) -> bool:
return attribute in self.__fields_set__

class RecordUpdateWithId(RecordUpdate):
id: UUID
def has_changes(self) -> bool:
return self.dict(exclude_unset=True) != {}
frascuchon marked this conversation as resolved.
Show resolved Hide resolved


class RecordUpsert(RecordCreate):
id: Optional[UUID]
fields: Optional[Dict[str, FieldValueCreate]] = None

def is_set(self, attribute: str) -> bool:
return attribute in self.__fields_set__


class RecordIncludeParam(BaseModel):
relationships: Optional[List[RecordInclude]] = Field(None, alias="keys")
Expand Down Expand Up @@ -245,11 +244,6 @@ class RecordsCreate(BaseModel):
items: List[RecordCreate] = Field(..., min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS)


class RecordsUpdate(BaseModel):
# TODO: review this definition and align to create model
items: List[RecordUpdateWithId] = Field(..., min_items=RECORDS_UPDATE_MIN_ITEMS, max_items=RECORDS_UPDATE_MAX_ITEMS)


class MetadataParsedQueryParam:
def __init__(self, string: str):
k, *v = string.split(":", maxsplit=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RecordsBulk(BaseModel):
items: List[Record]


class RecordsBulkWithUpdateInfo(RecordsBulk):
class RecordsBulkWithUpdatedItemIds(RecordsBulk):
updated_item_ids: List[UUID]


Expand Down
24 changes: 13 additions & 11 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Dict, List, Sequence, Tuple, Union
from uuid import UUID

from datetime import UTC
from fastapi.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -26,7 +27,7 @@
RecordsBulk,
RecordsBulkCreate,
RecordsBulkUpsert,
RecordsBulkWithUpdateInfo,
RecordsBulkWithUpdatedItemIds,
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
Expand All @@ -36,7 +37,7 @@
fetch_records_by_ids_as_dict,
)
from argilla_server.errors.future import UnprocessableEntityError
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.records import RecordsBulkCreateValidator, RecordUpsertValidator

Expand Down Expand Up @@ -139,15 +140,11 @@ async def _upsert_records_vectors(
autocommit=False,
)

@classmethod
def _metadata_is_set(cls, record_create: RecordCreate) -> bool:
return "metadata" in record_create.__fields_set__


class UpsertRecordsBulk(CreateRecordsBulk):
async def upsert_records_bulk(
self, dataset: Dataset, bulk_upsert: RecordsBulkUpsert, raise_on_error: bool = True
) -> RecordsBulkWithUpdateInfo:
) -> RecordsBulkWithUpdatedItemIds:
found_records = await self._fetch_existing_dataset_records(dataset, bulk_upsert.items)

records = []
Expand All @@ -170,9 +167,14 @@ async def upsert_records_bulk(
external_id=record_upsert.external_id,
dataset_id=dataset.id,
)
elif self._metadata_is_set(record_upsert):
record.metadata_ = record_upsert.metadata
record.updated_at = datetime.utcnow()
else:
if record_upsert.is_set("metadata"):
record.metadata_ = record_upsert.metadata
if record_upsert.is_set("fields"):
record.fields = jsonable_encoder(record_upsert.fields)

if self._db.is_modified(record):
record.updated_at = datetime.now(UTC)

records.append(record)

Expand All @@ -186,7 +188,7 @@ async def upsert_records_bulk(
await _preload_records_relationships_before_index(self._db, records)
await self._search_engine.index_records(dataset, records)

return RecordsBulkWithUpdateInfo(
return RecordsBulkWithUpdatedItemIds(
items=records,
updated_item_ids=[record.id for record in found_records.values()],
)
Expand Down
Loading