diff --git a/CHANGES.rst b/CHANGES.rst index 89e76af..1a6cc14 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,8 @@ Changes - Allow to use the package with Python 3.13 -- Caution: No security audit has been done so far. +- Add support for the matmul (``@``) operator. + 7.0 (2023-11-17) ---------------- diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index c6e2e78..66ae50f 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -768,8 +768,8 @@ def visit_BitAnd(self, node): return self.node_contents_visit(node) def visit_MatMult(self, node): - """Matrix multiplication (`@`) is currently not allowed.""" - self.not_allowed(node) + """Allow multiplication (`@`).""" + return self.node_contents_visit(node) def visit_BoolOp(self, node): """Allow bool operator without restrictions.""" diff --git a/tests/transformer/operators/test_arithmetic_operators.py b/tests/transformer/operators/test_arithmetic_operators.py index 209eb50..0d5722b 100644 --- a/tests/transformer/operators/test_arithmetic_operators.py +++ b/tests/transformer/operators/test_arithmetic_operators.py @@ -1,4 +1,3 @@ -from RestrictedPython import compile_restricted_eval from tests.helper import restricted_eval @@ -33,8 +32,12 @@ def test_FloorDiv(): def test_MatMult(): - result = compile_restricted_eval('(8, 3, 5) @ (2, 7, 1)') - assert result.errors == ( - 'Line None: MatMult statements are not allowed.', - ) - assert result.code is None + class Vector: + def __init__(self, values): + self.values = values + + def __matmul__(self, other): + return sum(x * y for x, y in zip(self.values, other.values)) + + assert restricted_eval( + 'Vector((8, 3, 5)) @ Vector((2, 7, 1))', {'Vector': Vector}) == 42