diff --git a/python/dex/interop/jax/apply.py b/python/dex/interop/jax/apply.py index a698b1969..149c99240 100644 --- a/python/dex/interop/jax/apply.py +++ b/python/dex/interop/jax/apply.py @@ -202,7 +202,7 @@ def dex_call_batched(batched_args, batched_dims, func_atom): batched_dims_it = iter(batched_dims) for binder in native.argument_signature: if binder.implicit: - batched_args.append("{" + binder.name + "}") + batched_args.append("(given (" + binder.name + "))") else: ty = binder.type.dex_annotation() if next(batched_dims_it) is not batching.not_mapped: @@ -304,7 +304,7 @@ def dex_call_jvp(arg_values, arg_tangents, func_atom): # ``` # \ (given (n1)) (given (n2)) (given (n3)) (p1:ty1) (p2:ty2) (p3:ty3) (t1:ty1) (t2:ty2) (t3:ty3). # linearized = linearize(func_uncurried, (p1, p2, p3)) - # snd(linearized((t1, t2, t3))) + # snd(linearized)((t1, t2, t3)) # ``` evaluate_linearized = module.eval( f"\\ {juxt_string(implicit_args)} {juxt_arg_string(primals, name_to_ty)} {juxt_arg_string(tangents, name_to_ty)}." + @@ -366,8 +366,8 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom): # Concretely, if `f` has three primal arguments, `func_atom` should look like: # ``` # \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t0 t1 t2. - # intermediate_linearized = linearize f (x0, x1, x2) - # snd intermediate_linearized (t0, t1, t2) + # intermediate_linearized = linearize(f, (x0, x1, x2)) + # snd(intermediate_linearized)((t0, t1, t2)) # ``` # In particular, its explicit arguments are assumed to be # `num_primals` primal inputs, followed by `num_primals` tangent @@ -472,12 +472,14 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom): linearized_inputs = juxt_string(primal_params + tangent_params) # \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t1 ct. - # transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 t1 t2) ct + # transpose_linear(\. + # (t0, t2) = + # linearized x0 x1 x2 t0 t1 t2)(ct) transposed = module.eval( - f"\\ {transposed_atom_params}. transpose_linear" + - f"(\\ {linear_lambda_param}." + - f"\n {tuple_unpack_string(tangent_name, tangent_inputs)}" + - f"\n {linearized_name} {linearized_inputs})(ct)\n" + f"\\ {transposed_atom_params}." + f"\n transpose_linear(\\ {linear_lambda_param}." + + f"\n {tuple_unpack_string(tangent_name, tangent_inputs)}" + + f"\n {linearized_name} {linearized_inputs})(ct)\n" ) # Tuple of cotangents relating to linear tangent inputs. In the given diff --git a/python/tests/jax_test.py b/python/tests/jax_test.py index c838dd32a..7081505ac 100644 --- a/python/tests/jax_test.py +++ b/python/tests/jax_test.py @@ -226,7 +226,7 @@ def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) = def test_dex_not_knowing_shape_vmap(self): m = dex.Module(dedent(""" - def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) = + def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) = for i. x[i] * x[i] """)) dex_sqr = primitive(m.sqr) @@ -237,7 +237,7 @@ def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) = def test_dex_not_knowing_shape_jvp(self): m = dex.Module(dedent(""" - def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) = + def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) = for i. x[i] * x[i] """)) dex_sqr = primitive(m.sqr) @@ -338,7 +338,7 @@ def f_jax(x, y): def test_interleave_implicit_args_vjp(self): f_dex = primitive(dex.eval( - r'\(given (n:Nat)) x:((Fin n) => Float) {m} y:((Fin n) => (Fin m) => Float). ' + r'\(given (n:Nat)) x:((Fin n) => Float) (given (m)) y:((Fin n) => (Fin m) => Float). ' 'for i. x[i] * x[i] + 2.0 * sum(y[i])')) def f_with_dex(x, y): diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index dc0680a36..e6f8e6774 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -476,7 +476,7 @@ instance InfBuilder (InfererM i) where Abs infFrag' $ Abs b' result HoistFailure vs -> do throw EscapedNameErr $ (pprint vs) - ++ "\nFailed to exchage binders in buildAbsInf" + ++ "\nFailed to exchange binders in buildAbsInf" ++ "\n" ++ pprint infFrag Abs b e <- return ab ty' <- zonk ty diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index 52a2f4f1e..c32378219 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -200,7 +200,7 @@ def passthrough(f:(a)->{|eff} b, x:a) -> {|eff} b f : (aa:Type, aa) -> aa = \bb x. myId x f Int 1 > Leaked local variables:[bb] -> Failed to exchage binders in buildAbsInf +> Failed to exchange binders in buildAbsInf > Pending emissions: > Defaults: > Solver substitution: [(_.3, bb)]