Skip to content

Commit

Permalink
Fix performance regression for JSON tagged union (#1552)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Nov 21, 2024
1 parent 4477692 commit e4de8a6
Showing 1 changed file with 58 additions and 75 deletions.
133 changes: 58 additions & 75 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,44 +117,6 @@ fn union_serialize<S>(
Ok(None)
}

fn tagged_union_serialize<S>(
discriminator_value: Option<Py<PyAny>>,
lookup: &HashMap<String, usize>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = discriminator_value {
let tag_str = tag.to_string();
if let Some(&serializer_index) = lookup.get(&tag_str) {
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
}

// if we haven't returned at this point, we should fallback to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
fn to_python(
&self,
Expand Down Expand Up @@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
self.tagged_union_serialize(
value,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
self.tagged_union_serialize(
key,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}
Expand All @@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
match tagged_union_serialize(
None,
&self.lookup,
match self.tagged_union_serialize(
value,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Expand All @@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer {
}

impl TaggedUnionSerializer {
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
fn get_discriminator_value<'py>(&self, value: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
let py = value.py();
let discriminator_value = match &self.discriminator {
match &self.discriminator {
Discriminator::LookupKey(lookup_key) => {
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
// at this point. we could be more strict and only do this in lax mode...
let getattr_result = match value.is_instance_of::<PyDict>() {
true => {
let value_dict = value.downcast::<PyDict>().unwrap();
lookup_key.py_get_dict_item(value_dict).ok()
}
false => lookup_key.simple_py_get_attr(value).ok(),
};
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
if let Ok(value_dict) = value.downcast::<PyDict>() {
lookup_key.py_get_dict_item(value_dict).ok().flatten()
} else {
lookup_key.simple_py_get_attr(value).ok().flatten()
}
.map(|(_, tag)| tag)
}
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let value_str = truncate_safe_repr(value, None);
Discriminator::Function(func) => func.bind(py).call1((value,)).ok(),
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
if extra.check == SerCheck::None {
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
fn tagged_union_serialize<S>(
&self,
value: &Bound<'_, PyAny>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
) -> PyResult<Option<S>> {
if let Some(tag) = self.get_discriminator_value(value) {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let selected_serializer = &self.choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
} else if extra.check == SerCheck::None {
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise
// this warning
let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
discriminator_value

// if we haven't returned at this point, we should fallback to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, &self.choices, self.retry_with_lax_check())
}
}

0 comments on commit e4de8a6

Please sign in to comment.