Skip to content

Commit

Permalink
Fix remaining tests except for one bug where transposition generates …
Browse files Browse the repository at this point in the history
…an expression which doesn't compile because of a leaked variable error.
  • Loading branch information
axch committed Mar 28, 2023
1 parent c4d4af9 commit 989f420
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
20 changes: 11 additions & 9 deletions python/dex/interop/jax/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def dex_call_batched(batched_args, batched_dims, func_atom):
batched_dims_it = iter(batched_dims)
for binder in native.argument_signature:
if binder.implicit:
batched_args.append("{" + binder.name + "}")
batched_args.append("(given (" + binder.name + "))")
else:
ty = binder.type.dex_annotation()
if next(batched_dims_it) is not batching.not_mapped:
Expand Down Expand Up @@ -304,7 +304,7 @@ def dex_call_jvp(arg_values, arg_tangents, func_atom):
# ```
# \ (given (n1)) (given (n2)) (given (n3)) (p1:ty1) (p2:ty2) (p3:ty3) (t1:ty1) (t2:ty2) (t3:ty3).
# linearized = linearize(func_uncurried, (p1, p2, p3))
# snd(linearized((t1, t2, t3)))
# snd(linearized)((t1, t2, t3))
# ```
evaluate_linearized = module.eval(
f"\\ {juxt_string(implicit_args)} {juxt_arg_string(primals, name_to_ty)} {juxt_arg_string(tangents, name_to_ty)}." +
Expand Down Expand Up @@ -366,8 +366,8 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
# Concretely, if `f` has three primal arguments, `func_atom` should look like:
# ```
# \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t0 t1 t2.
# intermediate_linearized = linearize f (x0, x1, x2)
# snd intermediate_linearized (t0, t1, t2)
# intermediate_linearized = linearize(f, (x0, x1, x2))
# snd(intermediate_linearized)((t0, t1, t2))
# ```
# In particular, its explicit arguments are assumed to be
# `num_primals` primal inputs, followed by `num_primals` tangent
Expand Down Expand Up @@ -472,12 +472,14 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
linearized_inputs = juxt_string(primal_params + tangent_params)

# \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t1 ct.
# transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 t1 t2) ct
# transpose_linear(\<fresh name>.
# (t0, t2) = <fresh name>
# linearized x0 x1 x2 t0 t1 t2)(ct)
transposed = module.eval(
f"\\ {transposed_atom_params}. transpose_linear" +
f"(\\ {linear_lambda_param}." +
f"\n {tuple_unpack_string(tangent_name, tangent_inputs)}" +
f"\n {linearized_name} {linearized_inputs})(ct)\n"
f"\\ {transposed_atom_params}."
f"\n transpose_linear(\\ {linear_lambda_param}." +
f"\n {tuple_unpack_string(tangent_name, tangent_inputs)}" +
f"\n {linearized_name} {linearized_inputs})(ct)\n"
)

# Tuple of cotangents relating to linear tangent inputs. In the given
Expand Down
6 changes: 3 additions & 3 deletions python/tests/jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) =

def test_dex_not_knowing_shape_vmap(self):
m = dex.Module(dedent("""
def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) =
def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) =
for i. x[i] * x[i]
"""))
dex_sqr = primitive(m.sqr)
Expand All @@ -237,7 +237,7 @@ def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) =

def test_dex_not_knowing_shape_jvp(self):
m = dex.Module(dedent("""
def sqr(x:(Fin n => Float)) : Fin n => Float given (n:Nat) =
def sqr(x:(Fin n => Float)) -> Fin n => Float given (n:Nat) =
for i. x[i] * x[i]
"""))
dex_sqr = primitive(m.sqr)
Expand Down Expand Up @@ -338,7 +338,7 @@ def f_jax(x, y):

def test_interleave_implicit_args_vjp(self):
f_dex = primitive(dex.eval(
r'\(given (n:Nat)) x:((Fin n) => Float) {m} y:((Fin n) => (Fin m) => Float). '
r'\(given (n:Nat)) x:((Fin n) => Float) (given (m)) y:((Fin n) => (Fin m) => Float). '
'for i. x[i] * x[i] + 2.0 * sum(y[i])'))

def f_with_dex(x, y):
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ instance InfBuilder (InfererM i) where
Abs infFrag' $ Abs b' result
HoistFailure vs -> do
throw EscapedNameErr $ (pprint vs)
++ "\nFailed to exchage binders in buildAbsInf"
++ "\nFailed to exchange binders in buildAbsInf"
++ "\n" ++ pprint infFrag
Abs b e <- return ab
ty' <- zonk ty
Expand Down
2 changes: 1 addition & 1 deletion tests/uexpr-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def passthrough(f:(a)->{|eff} b, x:a) -> {|eff} b
f : (aa:Type, aa) -> aa = \bb x. myId x
f Int 1
> Leaked local variables:[bb]
> Failed to exchage binders in buildAbsInf
> Failed to exchange binders in buildAbsInf
> Pending emissions:
> Defaults:
> Solver substitution: [(_.3, bb)]
Expand Down

0 comments on commit 989f420

Please sign in to comment.