Skip to content

Commit

Permalink
fix pydantic v2 issues
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jan 4, 2024
1 parent ebf5ee2 commit 8d19f34
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ dependencies = [
"obspy >= 1.3.0",
"pandas >= 1.0",
"progressbar2",
"pydantic >= 2.0",
"pydantic >= 2.0, <3.0",
"scipy",
"tables",
"typing-extensions",
Expand Down
7 changes: 3 additions & 4 deletions src/obsplus/events/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _events_to_model(
catalog = Catalog(events=[catalog])
elif not isinstance(catalog, Catalog): # sequence was passed
catalog = Catalog(events=catalog)
model = event_schema.Catalog.model_validate(catalog)
model = event_schema.Catalog.model_validate(catalog, from_attributes=True)
return model


Expand All @@ -40,10 +40,9 @@ def cat_to_json(catalog: Union[Catalog, Event, Iterable[Event]]) -> str:

def cat_to_dict(catalog: Union[Catalog, Event, Iterable[Event]]) -> dict:
"""
Convert an event object to a
Convert an event object to a dictionary.
"""
model = _events_to_model(catalog)
return model.dict()
return _events_to_model(catalog).model_dump()


def dict_to_cat(cjson: Union[dict, str]) -> Catalog:
Expand Down
68 changes: 32 additions & 36 deletions src/obsplus/events/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@
model.
"""
from datetime import datetime
from typing import Optional, List
from typing import Optional, List, Union
from uuid import uuid4
from typing_extensions import Literal, Annotated


import obspy.core.event as ev
from obspy.core.util.attribdict import AttribDict
from obsplus.constants import NSLC
from obspy import UTCDateTime
from pydantic import (
model_validator,
ConfigDict,
BaseModel,
field_validator,
PlainValidator,
Field,
)
from typing_extensions import Literal, Annotated

from obsplus.constants import NSLC

# ----- Type Literals (enum like)

Expand Down Expand Up @@ -143,6 +146,14 @@ def _recursive_dict(attrib):

AttribDictType = Annotated[AttribDict, PlainValidator(_recursive_dict)]


def _to_datetime(dt: Union[datetime, UTCDateTime]) -> datetime:
"""Convert object to datatime."""
return UTCDateTime(dt).datetime


UTCDateTimeFormat = Annotated[UTCDateTime, PlainValidator(_to_datetime)]

# ----- Type Models


Expand All @@ -167,46 +178,31 @@ def to_obspy(self):
"""Convert to obspy objects."""
name = self.__class__.__name__
cls = getattr(ev, name)
# Note: converting to a dict is deprecated, but we don't want
# to model dump because that is recursive, so we use this
# ugly hack to just get all attributes
out = {}
# get schema and properties
schema = self.model_json_schema()
props = schema["properties"]
array_props = {x for x, y in props.items() if y.get("type") == "array"}
# iterate each property and convert back to obspy
for prop in props:
val = getattr(self, prop)
if prop in array_props:
out[prop] = [self._convert_to_obspy(x) for x in val]
for i in self.model_fields:
val = getattr(self, i)
if isinstance(val, (list, tuple)):
out[i] = [self._convert_to_obspy(x) for x in val]
else:
out[prop] = self._convert_to_obspy(val)
out[i] = self._convert_to_obspy(val)
return cls(**out)


class ResourceIdentifier(_ObsPyModel):
"""Resource ID"""

id: Optional[str] = None

@field_validator("id", mode="before")
def get_id(cls, values):
"""Get the id string from the resource id"""
value = values.get("id") if hasattr(values, "get") else values
if value is None:
value = str(uuid4())
return value
id: str = Field(default_factory=lambda: str(uuid4()))


class _ModelWithResourceID(_ObsPyModel):
"""A model which has a resource ID"""

resource_id: Optional[ResourceIdentifier] = None

@field_validator("resource_id", mode="before")
def get_resource_id(cls, value):
"""Ensure a valid str is returned."""
if value is None:
return str(uuid4())
return value
resource_id: ResourceIdentifier = Field(
default_factory=lambda: ResourceIdentifier()
)


class QuantityError(_ObsPyModel):
Expand All @@ -225,7 +221,7 @@ class CreationInfo(_ObsPyModel):
agency_uri: Optional[ResourceIdentifier] = None
author: Optional[str] = None
author_uri: Optional[ResourceIdentifier] = None
creation_time: Optional[datetime] = None
creation_time: Optional[UTCDateTimeFormat] = None
version: Optional[str] = None


Expand All @@ -234,7 +230,7 @@ class TimeWindow(_ObsPyModel):

begin: Optional[float] = None
end: Optional[float] = None
reference: Optional[datetime] = None
reference: Optional[UTCDateTimeFormat] = None


class CompositeTime(_ObsPyModel):
Expand Down Expand Up @@ -350,7 +346,7 @@ class Amplitude(_ModelWithResourceID):
pick_id: Optional[ResourceIdentifier] = None
waveform_id: Optional[WaveformStreamID] = None
filter_id: Optional[ResourceIdentifier] = None
scaling_time: Optional[datetime] = None
scaling_time: Optional[UTCDateTimeFormat] = None
scaling_time_errors: Optional[QuantityError] = None
magnitude_hint: Optional[str] = None
evaluation_mode: Optional[EvaluationMode] = None
Expand Down Expand Up @@ -394,7 +390,7 @@ class OriginQuality(_ObsPyModel):
class Pick(_ModelWithResourceID):
"""Pick"""

time: Optional[datetime] = None
time: Optional[UTCDateTimeFormat] = None
time_errors: Optional[QuantityError] = None
waveform_id: Optional[WaveformStreamID] = None
filter_id: Optional[ResourceIdentifier] = None
Expand Down Expand Up @@ -439,7 +435,7 @@ class Arrival(_ModelWithResourceID):
class Origin(_ModelWithResourceID):
"""Origin"""

time: datetime
time: UTCDateTimeFormat
time_errors: Optional[QuantityError] = None
longitude: Optional[float] = None
longitude_errors: Optional[QuantityError] = None
Expand Down
2 changes: 1 addition & 1 deletion src/obsplus/utils/stations.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _get_station_client_response(self, kwargs):
f"More than one channel returned by client with kwargs:"
f"{kwargs}, add constraints to resolve the issue"
)
raise AmbiguousResponseError(msg)
warnings.warn(msg)
return sub_inv[0][0][0].response

def _update_nrl_response(self, response, df):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_events/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ class TestResourceID:

def test_null(self):
"""Ensure Null generates a resource ID a la ObsPy"""

import obspy

obspy._debug = True

rid = esc.ResourceIdentifier()
assert isinstance(rid.id, str)
assert len(rid.id)

def test_defined_resource_id(self):
"""Ensure the defined resource_id sticks."""

rid = str(ev.ResourceIdentifier())
out = esc.ResourceIdentifier(id=rid)
assert out.id == rid
Expand Down
18 changes: 10 additions & 8 deletions tests/test_utils/test_stations_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import os


import numpy as np
import obspy
import pandas as pd
Expand Down Expand Up @@ -284,11 +285,11 @@ def df_with_get_stations_kwargs(self):
# now add a row with an empty get_station_kwargs column
old = dict(df.iloc[0])
new = {
"station": "CWU",
"network": "UU",
"channel": "EHZ",
"location": "01",
"seed_id": "UU.CWU.01.EHZ",
"station": "P20A",
"network": "TA",
"channel": "BHZ",
"location": "",
"seed_id": "TA.P20A..BHZ",
"get_station_kwargs": "{}",
}
old.update(new)
Expand Down Expand Up @@ -365,11 +366,12 @@ def test_mixing_nrl_with_station_client(self, df_with_both_response_cols):
with pytest.raises(AmbiguousResponseError):
df_to_inventory(df)

def test_ambiguous_query_raises(self, df_ambiguous_client_query):
"""Ensure a query that returns multiple channels will raise."""
def test_ambiguous_query_warns(self, df_ambiguous_client_query):
"""Ensure a query that returns multiple channels will warn."""

df = df_ambiguous_client_query
with pytest.raises(AmbiguousResponseError):
msg = "More than one channel returned by client"
with pytest.warns(UserWarning, match=msg):
df_to_inventory(df)


Expand Down

0 comments on commit 8d19f34

Please sign in to comment.