Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for unpacked TypedDict to type hint variadic keyword arguments in ArgumentsValidator #1451

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3372,11 +3372,15 @@ def arguments_parameter(
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict']


class ArgumentsSchema(TypedDict, total=False):
type: Required[Literal['arguments']]
arguments_schema: Required[List[ArgumentsParameter]]
populate_by_name: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
var_kwargs_schema: CoreSchema
ref: str
metadata: Dict[str, Any]
Expand All @@ -3388,6 +3392,7 @@ def arguments_schema(
*,
populate_by_name: bool | None = None,
var_args_schema: CoreSchema | None = None,
var_kwargs_mode: VarKwargsMode | None = None,
var_kwargs_schema: CoreSchema | None = None,
ref: str | None = None,
metadata: Dict[str, Any] | None = None,
Expand All @@ -3414,6 +3419,9 @@ def arguments_schema(
arguments: The arguments to use for the arguments schema
populate_by_name: Whether to populate by name
var_args_schema: The variable args schema to use for the arguments schema
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
the `var_kwargs_schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3424,6 +3432,7 @@ def arguments_schema(
arguments_schema=arguments,
populate_by_name=populate_by_name,
var_args_schema=var_args_schema,
var_kwargs_mode=var_kwargs_mode,
var_kwargs_schema=var_kwargs_schema,
ref=ref,
metadata=metadata,
Expand Down
108 changes: 87 additions & 21 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
Expand All @@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
Uniform,
UnpackedTypedDict,
}

impl FromStr for VarKwargsMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"uniform" => Ok(Self::Uniform),
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
s => py_schema_err!(
"Invalid var_kwargs mode: `{}`, expected `uniform` or `unpacked-typed-dict`",
s
),
}
}
}

#[derive(Debug)]
struct Parameter {
positional: bool,
Expand All @@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
parameters: Vec<Parameter>,
positional_params_count: usize,
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_mode: VarKwargsMode,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
Expand Down Expand Up @@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
});
}

let py_var_kwargs_mode: Bound<PyString> = schema
.get_as(intern!(py, "var_kwargs_mode"))?
.unwrap_or_else(|| PyString::new_bound(py, "uniform"));

let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_str()?)?;
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
};
Viicos marked this conversation as resolved.
Show resolved Hide resolved

if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should also check that var_kwargs_validator is a TypedDictValidator?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to, given that you have the conditional checks in pydantic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it a TypedDictValdiator statically (at build time) would allow for efficiency gains (by avoiding the dispatch via CombinedValidator).

But I wonder, are there cases where it can be wrapped in a function-after validator (e.g. model_validator)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making it a TypedDictValdiator statically (at build time) would allow for efficiency gains (by avoiding the dispatch via CombinedValidator).

The thing is (as described here) var_kwargs_validator is used for both uniform and unpacked-typed-dict modes.

But I wonder, are there cases where it can be wrapped in a function-after validator (e.g. model_validator)?

only config can be attached to typed dicts iirc, and it still results in a typed dict schema.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
    Uniform(CombinedValidator),
    UnpackedTypedDict(TypedDictValidator),
}

... and replace the FromStr implementation with a bespoke function like from_string_and_validator or similar.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will leave this as an optimization for later, going to approve for now 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1457

return py_schema_err!(
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
);
}

Ok(Self {
parameters,
positional_params_count,
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
Viicos marked this conversation as resolved.
Show resolved Hide resolved
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_mode,
var_kwargs_validator,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
Expand Down Expand Up @@ -255,6 +293,9 @@ impl Validator for ArgumentsValidator {
}
}
}

let remaining_kwargs = PyDict::new_bound(py);

// if there are kwargs check any that haven't been processed yet
if let Some(kwargs) = args.kwargs() {
if kwargs.len() > used_kwargs.len() {
Expand All @@ -278,33 +319,58 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
match self.var_kwargs_mode {
VarKwargsMode::Uniform => match &self.var_kwargs_validator {
Some(validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
VarKwargsMode::UnpackedTypedDict => {
// Save to the remaining kwargs, we will validate as a single dict:
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
}
}
}

if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
// `var_kwargs_validator` is guaranteed to be `Some`:
match self
.var_kwargs_validator
.as_ref()
.unwrap()
.validate(py, remaining_kwargs.as_any(), state)
{
Ok(value) => {
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors);
}
Err(err) => return Err(err),
}
}

if !errors.is_empty() {
Err(ValError::LineErrors(errors))
} else {
Expand Down
57 changes: 56 additions & 1 deletion tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,19 @@ def test_build_non_default_follows():
)


def test_build_missing_var_kwargs():
with pytest.raises(
SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
):
SchemaValidator(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
}
)


@pytest.mark.parametrize(
'input_value,expected',
[
Expand All @@ -778,7 +791,7 @@ def test_build_non_default_follows():
],
ids=repr,
)
def test_kwargs(py_and_json: PyAndJson, input_value, expected):
def test_kwargs_uniform(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
Expand All @@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
[ArgsKwargs((), {'x': 1}), ((), {'x': 1})],
[ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')],
[ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})],
[ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')],
],
)
def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
'var_kwargs_schema': {
'type': 'typed-dict',
'fields': {
'x': {
'type': 'typed-dict-field',
'schema': {'type': 'int', 'strict': True},
'required': True,
},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'str'},
'required': False,
'validation_alias': 'z',
},
},
'config': {'extra_fields_behavior': 'forbid'},
},
}
)

if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
Expand Down
Loading