From ebf5ee2ff2a106122205671e08497150bd73d2d8 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Fri, 21 Jul 2023 21:31:38 -0600 Subject: [PATCH] work on pydantic v2 --- src/obsplus/events/json.py | 2 +- src/obsplus/events/schema.py | 32 ++++++++++++++++++++++++++------ tests/test_events/test_schema.py | 6 +++--- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/obsplus/events/json.py b/src/obsplus/events/json.py index d140757d..eb468c9c 100644 --- a/src/obsplus/events/json.py +++ b/src/obsplus/events/json.py @@ -26,7 +26,7 @@ def _events_to_model( catalog = Catalog(events=[catalog]) elif not isinstance(catalog, Catalog): # sequence was passed catalog = Catalog(events=catalog) - model = event_schema.Catalog.from_orm(catalog) + model = event_schema.Catalog.model_validate(catalog) return model diff --git a/src/obsplus/events/schema.py b/src/obsplus/events/schema.py index 649087db..70b5454c 100644 --- a/src/obsplus/events/schema.py +++ b/src/obsplus/events/schema.py @@ -9,9 +9,16 @@ from uuid import uuid4 import obspy.core.event as ev +from obspy.core.util.attribdict import AttribDict from obsplus.constants import NSLC -from pydantic import model_validator, ConfigDict, BaseModel, field_validator -from typing_extensions import Literal +from pydantic import ( + model_validator, + ConfigDict, + BaseModel, + field_validator, + PlainValidator, +) +from typing_extensions import Literal, Annotated # ----- Type Literals (enum like) @@ -125,6 +132,17 @@ SourceTimeFunctionType = Literal["box car", "triangle", "trapezoid", "unknown"] +def _recursive_dict(attrib): + """recursively turn all AttribDict s into normal dicts.""" + out = dict(attrib) + for i, v in out.items(): + if isinstance(v, AttribDict): + out[i] = _recursive_dict(v) + return out + + +AttribDictType = Annotated[AttribDict, PlainValidator(_recursive_dict)] + # ----- Type Models @@ -133,9 +151,11 @@ class _ObsPyModel(BaseModel): validate_assignment=True, arbitrary_types_allowed=True, from_attributes=True, - extra="allow", + extra="ignore", ) + # extra: Optional[AttribDictType] = None + @staticmethod def _convert_to_obspy(value): """Convert an object to obspy or return value.""" @@ -149,7 +169,7 @@ def to_obspy(self): cls = getattr(ev, name) out = {} # get schema and properties - schema = self.schema() + schema = self.model_json_schema() props = schema["properties"] array_props = {x for x, y in props.items() if y.get("type") == "array"} # iterate each property and convert back to obspy @@ -170,10 +190,10 @@ class ResourceIdentifier(_ObsPyModel): @field_validator("id", mode="before") def get_id(cls, values): """Get the id string from the resource id""" - value = values.get("id") + value = values.get("id") if hasattr(values, "get") else values if value is None: value = str(uuid4()) - return {"id": value} + return value class _ModelWithResourceID(_ObsPyModel): diff --git a/tests/test_events/test_schema.py b/tests/test_events/test_schema.py index 2c6c48f2..78a83901 100644 --- a/tests/test_events/test_schema.py +++ b/tests/test_events/test_schema.py @@ -90,7 +90,7 @@ def test_from_simple_obspy(self): def test_from_obspy_catalog(self, test_catalog): """Ensure pydantic models can be generated from Obspy objects""" - out = esc.Catalog.from_orm(test_catalog) + out = esc.Catalog.model_validate(test_catalog) assert isinstance(out, esc.Catalog) assert len(out.events) == len(test_catalog.events) self.assert_lens_equal(out, test_catalog) @@ -98,12 +98,12 @@ def test_from_obspy_catalog(self, test_catalog): def test_from_json(self, test_catalog): """Ensure the catalog can be created from json.""" catalog_dict = cat_to_dict(test_catalog) - out = esc.Catalog.parse_obj(catalog_dict) + out = esc.Catalog.model_validate(catalog_dict) assert isinstance(out, esc.Catalog) assert len(out.events) == len(catalog_dict["events"]) def test_round_trip(self, test_catalog): """Test converting from pydantic models to ObsPy.""" - pycat = esc.Catalog.from_orm(test_catalog) + pycat = esc.Catalog.model_validate(test_catalog) out = pycat.to_obspy() assert out == test_catalog