Skip to content

Commit

Permalink
work on pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jul 22, 2023
1 parent 7d18e6e commit ebf5ee2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/obsplus/events/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _events_to_model(
catalog = Catalog(events=[catalog])
elif not isinstance(catalog, Catalog): # sequence was passed
catalog = Catalog(events=catalog)
model = event_schema.Catalog.from_orm(catalog)
model = event_schema.Catalog.model_validate(catalog)
return model


Expand Down
32 changes: 26 additions & 6 deletions src/obsplus/events/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
from uuid import uuid4

import obspy.core.event as ev
from obspy.core.util.attribdict import AttribDict
from obsplus.constants import NSLC
from pydantic import model_validator, ConfigDict, BaseModel, field_validator
from typing_extensions import Literal
from pydantic import (
model_validator,
ConfigDict,
BaseModel,
field_validator,
PlainValidator,
)
from typing_extensions import Literal, Annotated

# ----- Type Literals (enum like)

Expand Down Expand Up @@ -125,6 +132,17 @@
SourceTimeFunctionType = Literal["box car", "triangle", "trapezoid", "unknown"]


def _recursive_dict(attrib):
"""recursively turn all AttribDict s into normal dicts."""
out = dict(attrib)
for i, v in out.items():
if isinstance(v, AttribDict):
out[i] = _recursive_dict(v)
return out


AttribDictType = Annotated[AttribDict, PlainValidator(_recursive_dict)]

# ----- Type Models


Expand All @@ -133,9 +151,11 @@ class _ObsPyModel(BaseModel):
validate_assignment=True,
arbitrary_types_allowed=True,
from_attributes=True,
extra="allow",
extra="ignore",
)

# extra: Optional[AttribDictType] = None

@staticmethod
def _convert_to_obspy(value):
"""Convert an object to obspy or return value."""
Expand All @@ -149,7 +169,7 @@ def to_obspy(self):
cls = getattr(ev, name)
out = {}
# get schema and properties
schema = self.schema()
schema = self.model_json_schema()
props = schema["properties"]
array_props = {x for x, y in props.items() if y.get("type") == "array"}
# iterate each property and convert back to obspy
Expand All @@ -170,10 +190,10 @@ class ResourceIdentifier(_ObsPyModel):
@field_validator("id", mode="before")
def get_id(cls, values):
"""Get the id string from the resource id"""
value = values.get("id")
value = values.get("id") if hasattr(values, "get") else values
if value is None:
value = str(uuid4())
return {"id": value}
return value


class _ModelWithResourceID(_ObsPyModel):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_events/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,20 @@ def test_from_simple_obspy(self):

def test_from_obspy_catalog(self, test_catalog):
"""Ensure pydantic models can be generated from Obspy objects"""
out = esc.Catalog.from_orm(test_catalog)
out = esc.Catalog.model_validate(test_catalog)
assert isinstance(out, esc.Catalog)
assert len(out.events) == len(test_catalog.events)
self.assert_lens_equal(out, test_catalog)

def test_from_json(self, test_catalog):
"""Ensure the catalog can be created from json."""
catalog_dict = cat_to_dict(test_catalog)
out = esc.Catalog.parse_obj(catalog_dict)
out = esc.Catalog.model_validate(catalog_dict)
assert isinstance(out, esc.Catalog)
assert len(out.events) == len(catalog_dict["events"])

def test_round_trip(self, test_catalog):
"""Test converting from pydantic models to ObsPy."""
pycat = esc.Catalog.from_orm(test_catalog)
pycat = esc.Catalog.model_validate(test_catalog)
out = pycat.to_obspy()
assert out == test_catalog

0 comments on commit ebf5ee2

Please sign in to comment.