Skip to content

Commit

Permalink
Add support for unpacked TypedDict to type hint variadic keyword ar…
Browse files Browse the repository at this point in the history
…guments in `ArgumentsValidator` (#1451)

Co-authored-by: Sydney Runkle <[email protected]>
Co-authored-by: David Hewitt <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2024
1 parent 8c1a0da commit f3f436e
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 22 deletions.
9 changes: 9 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3372,11 +3372,15 @@ def arguments_parameter(
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict']


class ArgumentsSchema(TypedDict, total=False):
type: Required[Literal['arguments']]
arguments_schema: Required[List[ArgumentsParameter]]
populate_by_name: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
var_kwargs_schema: CoreSchema
ref: str
metadata: Dict[str, Any]
Expand All @@ -3388,6 +3392,7 @@ def arguments_schema(
*,
populate_by_name: bool | None = None,
var_args_schema: CoreSchema | None = None,
var_kwargs_mode: VarKwargsMode | None = None,
var_kwargs_schema: CoreSchema | None = None,
ref: str | None = None,
metadata: Dict[str, Any] | None = None,
Expand All @@ -3414,6 +3419,9 @@ def arguments_schema(
arguments: The arguments to use for the arguments schema
populate_by_name: Whether to populate by name
var_args_schema: The variable args schema to use for the arguments schema
var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the
keyword arguments will be validated against the `var_kwargs_schema` schema. If `'unpacked-typed-dict'`,
the `var_kwargs_schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema]
var_kwargs_schema: The variable kwargs schema to use for the arguments schema
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
Expand All @@ -3424,6 +3432,7 @@ def arguments_schema(
arguments_schema=arguments,
populate_by_name=populate_by_name,
var_args_schema=var_args_schema,
var_kwargs_mode=var_kwargs_mode,
var_kwargs_schema=var_kwargs_schema,
ref=ref,
metadata=metadata,
Expand Down
108 changes: 87 additions & 21 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
Expand All @@ -15,6 +17,27 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, PartialEq)]
enum VarKwargsMode {
Uniform,
UnpackedTypedDict,
}

impl FromStr for VarKwargsMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"uniform" => Ok(Self::Uniform),
"unpacked-typed-dict" => Ok(Self::UnpackedTypedDict),
s => py_schema_err!(
"Invalid var_kwargs mode: `{}`, expected `uniform` or `unpacked-typed-dict`",
s
),
}
}
}

#[derive(Debug)]
struct Parameter {
positional: bool,
Expand All @@ -29,6 +52,7 @@ pub struct ArgumentsValidator {
parameters: Vec<Parameter>,
positional_params_count: usize,
var_args_validator: Option<Box<CombinedValidator>>,
var_kwargs_mode: VarKwargsMode,
var_kwargs_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
extra: ExtraBehavior,
Expand Down Expand Up @@ -117,17 +141,31 @@ impl BuildValidator for ArgumentsValidator {
});
}

let py_var_kwargs_mode: Bound<PyString> = schema
.get_as(intern!(py, "var_kwargs_mode"))?
.unwrap_or_else(|| PyString::new_bound(py, "uniform"));

let var_kwargs_mode = VarKwargsMode::from_str(py_var_kwargs_mode.to_str()?)?;
let var_kwargs_validator = match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
};

if var_kwargs_mode == VarKwargsMode::UnpackedTypedDict && var_kwargs_validator.is_none() {
return py_schema_err!(
"`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
);
}

Ok(Self {
parameters,
positional_params_count,
var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? {
Some(v) => Some(Box::new(build_validator(&v, config, definitions)?)),
None => None,
},
var_kwargs_mode,
var_kwargs_validator,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?,
}
Expand Down Expand Up @@ -255,6 +293,9 @@ impl Validator for ArgumentsValidator {
}
}
}

let remaining_kwargs = PyDict::new_bound(py);

// if there are kwargs check any that haven't been processed yet
if let Some(kwargs) = args.kwargs() {
if kwargs.len() > used_kwargs.len() {
Expand All @@ -278,33 +319,58 @@ impl Validator for ArgumentsValidator {
Err(err) => return Err(err),
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
match self.var_kwargs_mode {
VarKwargsMode::Uniform => match &self.var_kwargs_validator {
Some(validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_kwargs
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
}
Err(err) => return Err(err),
},
None => {
if let ExtraBehavior::Forbid = self.extra {
errors.push(ValLineError::new_with_loc(
ErrorTypeDefaults::UnexpectedKeywordArgument,
value,
raw_key.clone(),
));
}
VarKwargsMode::UnpackedTypedDict => {
// Save to the remaining kwargs, we will validate as a single dict:
remaining_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
}
}
}

if self.var_kwargs_mode == VarKwargsMode::UnpackedTypedDict {
// `var_kwargs_validator` is guaranteed to be `Some`:
match self
.var_kwargs_validator
.as_ref()
.unwrap()
.validate(py, remaining_kwargs.as_any(), state)
{
Ok(value) => {
output_kwargs.update(value.downcast_bound::<PyDict>(py).unwrap().as_mapping())?;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors);
}
Err(err) => return Err(err),
}
}

if !errors.is_empty() {
Err(ValError::LineErrors(errors))
} else {
Expand Down
57 changes: 56 additions & 1 deletion tests/validators/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,19 @@ def test_build_non_default_follows():
)


def test_build_missing_var_kwargs():
with pytest.raises(
SchemaError, match="`var_kwargs_schema` must be specified when `var_kwargs_mode` is `'unpacked-typed-dict'`"
):
SchemaValidator(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
}
)


@pytest.mark.parametrize(
'input_value,expected',
[
Expand All @@ -778,7 +791,7 @@ def test_build_non_default_follows():
],
ids=repr,
)
def test_kwargs(py_and_json: PyAndJson, input_value, expected):
def test_kwargs_uniform(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
Expand All @@ -796,6 +809,48 @@ def test_kwargs(py_and_json: PyAndJson, input_value, expected):
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
[ArgsKwargs((), {'x': 1}), ((), {'x': 1})],
[ArgsKwargs((), {'x': 1.0}), Err('x\n Input should be a valid integer [type=int_type,')],
[ArgsKwargs((), {'x': 1, 'z': 'str'}), ((), {'x': 1, 'y': 'str'})],
[ArgsKwargs((), {'x': 1, 'y': 'str'}), Err('y\n Extra inputs are not permitted [type=extra_forbidden,')],
],
)
def test_kwargs_typed_dict(py_and_json: PyAndJson, input_value, expected):
v = py_and_json(
{
'type': 'arguments',
'arguments_schema': [],
'var_kwargs_mode': 'unpacked-typed-dict',
'var_kwargs_schema': {
'type': 'typed-dict',
'fields': {
'x': {
'type': 'typed-dict-field',
'schema': {'type': 'int', 'strict': True},
'required': True,
},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'str'},
'required': False,
'validation_alias': 'z',
},
},
'config': {'extra_fields_behavior': 'forbid'},
},
}
)

if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
else:
assert v.validate_test(input_value) == expected


@pytest.mark.parametrize(
'input_value,expected',
[
Expand Down

0 comments on commit f3f436e

Please sign in to comment.