Skip to content

Commit

Permalink
Update (somehow!) the grad transformation to the new Dex syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Mar 28, 2023
1 parent a0e7664 commit c4d4af9
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions python/dex/interop/jax/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def dex_call_jvp(arg_values, arg_tangents, func_atom):
name_to_ty = {}
for binder in native.argument_signature:
if binder.implicit:
implicit_args.append("{" + binder.name + "}")
implicit_args.append("(given (" + binder.name + "))")
else:
annot = binder.type.dex_annotation()
p_name = f"p{binder.name}"
Expand All @@ -283,10 +283,16 @@ def dex_call_jvp(arg_values, arg_tangents, func_atom):
# this form. The evaluated string for three function arguments (and
# three implicit arguments) should look like:
# ```
# \ {n1} {n2} {n3} ((p1, p2, p3):(ty1 & ty2 & ty3)). func p1 p2 p3
# \ (given (n1)) (given (n2)) (given (n3)) (<fresh>:(ty1, ty2, ty3)).
# (p1, p2, p3) = fresh
# func p1 p2 p3
# ```
primal_name = api.freshName(module, 'primal')
expl_arg_string = tuple_arg_string(primal_name, primals, name_to_ty)
uncurried = module.eval(
f"\\ {juxt_string(implicit_args)} {tuple_arg_string(primals, name_to_ty)}. {func_name} {juxt_string(primals)}")
f"\\ {juxt_string(implicit_args)} {expl_arg_string}." +
f"\n {tuple_unpack_string(primal_name, primals)}" +
f"\n {func_name} {juxt_string(primals)}\n")
func_uncurried_name = uncurried.name
assert func_uncurried_name is not None

Expand All @@ -296,14 +302,14 @@ def dex_call_jvp(arg_values, arg_tangents, func_atom):
# Here we write out the tangent evaluation expression in pointful style.
# The evaluated string for three function arguments should look like:
# ```
# \ {n1} {n2} {n3} (p1:ty1) (p2:ty2) (p3:ty3) (t1:ty1) (t2:ty2) (t3:ty3).
# linearized = linearize func_uncurried (p1, p2, p3)
# snd linearized (t1, t2, t3)
# \ (given (n1)) (given (n2)) (given (n3)) (p1:ty1) (p2:ty2) (p3:ty3) (t1:ty1) (t2:ty2) (t3:ty3).
# linearized = linearize(func_uncurried, (p1, p2, p3))
# snd(linearized((t1, t2, t3)))
# ```
evaluate_linearized = module.eval(
f"\\ {juxt_string(implicit_args)} {juxt_arg_string(primals, name_to_ty)} {juxt_arg_string(tangents, name_to_ty)}." +
f"\n linearized = linearize {func_uncurried_name} {tuple_ref_string(primals)}" +
f"\n snd linearized {tuple_ref_string(tangents)}")
f"\n linearized = linearize({func_uncurried_name}, {tuple_ref_string(primals)})" +
f"\n snd(linearized)({tuple_ref_string(tangents)})\n")

# Materialize jax.ad.Zero values into actual arrays of zeros.
# TODO: Make the handling of Zeros more efficient by omitting them from the
Expand All @@ -329,21 +335,18 @@ def juxt_arg_string(names, name_to_ty):
annotated = [f"({name} : {name_to_ty[name]})" for name in names]
return juxt_string(annotated)

def tuple_arg_string(names, name_to_ty):
ty = tuple_ty_ref_string([name_to_ty[name] for name in names])
return f"({tuple_ref_string(names)} : {ty})"
def tuple_arg_string(name, names, name_to_ty):
ty = tuple_ref_string([name_to_ty[name] for name in names])
return f"({name} : {ty})"

def tuple_ref_string(names):
if len(names) == 1:
return names[0]
else:
return "(" + ", ".join(names) + ")"

def tuple_ty_ref_string(names):
if len(names) == 1:
return names[0]
else:
return "(" + " , ".join(names) + ")"
def tuple_unpack_string(name, names):
return f"{tuple_ref_string(names)} = {name}"

# === transpose ===

Expand All @@ -362,7 +365,7 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
# `dex_call_jvp`, applied to a some function atom, called `f`, say.
# Concretely, if `f` has three primal arguments, `func_atom` should look like:
# ```
# \ {n0} {n1} {n2} x0 x1 x2 t0 t1 t2.
# \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t0 t1 t2.
# intermediate_linearized = linearize f (x0, x1, x2)
# snd intermediate_linearized (t0, t1, t2)
# ```
Expand Down Expand Up @@ -392,7 +395,7 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
for binder in native.argument_signature:
if binder.implicit:
if hoistable(binder.type, name_to_ty.keys()):
implicit_args.append("{" + binder.name + "}")
implicit_args.append("(given (" + binder.name + "))")
else:
raise NotImplementedError(f"Hoist check failed: implicit {binder.name} of type {binder.type} depends on a previous explicit binder")
else:
Expand Down Expand Up @@ -441,8 +444,10 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
# For a three-input primal function with constant input for the tangent
# parameter at index 1, the evaluated string should look like:
# ```
# \ {n0} {n1} {n2} x0 x1 x2 t1 ct.
# transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 t1 t2) ct
# \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t1 ct.
# transpose_linear(\<fresh name>.
# (t0, t2) = <fresh name>
# linearized x0 x1 x2 t0 t1 t2)(ct)
# ```
# - The `x` variables are the (constant) inputs to the primal function. These
# should always be supplied by JAX.
Expand All @@ -451,24 +456,28 @@ def dex_call_evaluate_linearized_transpose(cotangents, *args, func_atom):
# - Note that we use the original names for the parameters, and include
# their type annotations. (TODO Include a type annotation for `ct` as well?)

# {n0} {n1} {n2} x0 x1 x2 t1 ct
# (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t1 ct
transposed_atom_params = (
juxt_string(implicit_args) + " " +
juxt_arg_string(primal_params, name_to_ty) + " " +
juxt_arg_string([tangent_params[i] for i in tangent_constant_indices], name_to_ty) + " ct")

# (t0, t2)
linear_lambda_params = tuple_arg_string(
[tangent_params[i] for i in tangent_input_indices], name_to_ty)
# <fresh name>
tangent_name = api.freshName(module, 'tangent')
tangent_inputs = [tangent_params[i] for i in tangent_input_indices]
linear_lambda_param = tuple_arg_string(tangent_name,
tangent_inputs, name_to_ty)

# x0 x1 x2 t0 t1 t2
linearized_inputs = juxt_string(primal_params + tangent_params)

# \ {n0} {n1} {n2} x0 x1 x2 t1 ct.
# \ (given (n0)) (given (n1)) (given (n2)) x0 x1 x2 t1 ct.
# transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 t1 t2) ct
transposed = module.eval(
f"\\ {transposed_atom_params}. transpose_linear " +
f"(\\ {linear_lambda_params}. {linearized_name} {linearized_inputs}) ct"
f"\\ {transposed_atom_params}. transpose_linear" +
f"(\\ {linear_lambda_param}." +
f"\n {tuple_unpack_string(tangent_name, tangent_inputs)}" +
f"\n {linearized_name} {linearized_inputs})(ct)\n"
)

# Tuple of cotangents relating to linear tangent inputs. In the given
Expand Down

0 comments on commit c4d4af9

Please sign in to comment.