Separate forward and backward compilation for default partition
ghstack-source-id: 248de2f577fe3d61d4f2c40dc04570978dcc1543
Pull Request resolved: #856
anjali411 committed Jun 14, 2022
1 parent 130582c commit dd0a862
Showing 3 changed files with 101 additions and 26 deletions.
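
In short, the change makes the default partition compile the backward separately from the forward: `forward` saves the partitioned forward's intermediates together with the inputs and outputs, and the backward graph is traced and compiled only on the first call to `backward()`, which is what lets the new tests check second-order gradients. The sketch below is not part of the commit; it only illustrates how such a compiled function is typically driven, assuming the `aot_function`, `nop` and `default_partition` exports from `functorch.compile`.

```python
# Minimal usage sketch (not from this commit). The compiler choices and the
# function f are illustrative only.
import torch
from functorch.compile import aot_function, nop, default_partition

def f(a, b):
    return (a * b).sin()

# nop leaves the traced graphs uncompiled; default_partition is the partition
# strategy this commit changes.
compiled_f = aot_function(f, fw_compiler=nop, bw_compiler=nop,
                          partition_fn=default_partition)

a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3)
out = compiled_f(a, b)

# With this change, the backward graph for the default partition is traced and
# compiled here, on the first backward call, not during the forward call above.
out.sum().backward()
print(a.grad)
```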
75 changes: 56 additions & 19 deletions functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:

 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)

         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}

     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,67 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1

                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                 # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                 # print(fw_module.code, bw_module.code)

                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.fx_g = fx_g
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                    ctx.bw_graph = bw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])

         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    assert fx_g.code == ctx.fx_g.code
+                f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                print(bw_module.code)
+                out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                return out
+            else:
+                if partition_fn is default_partition:
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args))
+                else:
+                    assert not torch.is_grad_enabled()
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)

     return CompiledFunction

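For `default_partition`, the rewritten `CompiledFunction` above saves `fw_outs[num_outs:] + list(flat_tensor_args) + out` in `forward` and defers backward compilation: `backward` re-traces the joint graph from the saved inputs and the incoming cotangents, partitions it again, and wraps the backward module in `aot_function` so it is only compiled when gradients are actually requested. The snippet below is a deliberately simplified sketch of that "compile the backward lazily inside `autograd.Function.backward`" pattern; the names are made up, and a plain Python closure stands in for tracing, partitioning and the backward compiler.

```python
# Simplified, hypothetical sketch of lazy backward compilation; not the
# committed implementation.
import torch

class LazyBackwardFn(torch.autograd.Function):
    _compiled_bw = None  # stands in for the cached compiled backward

    @staticmethod
    def forward(ctx, x):
        # Run the (compiled) forward and save what the backward will need.
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        if LazyBackwardFn._compiled_bw is None:
            # "Compile" the backward on first use. In the real code this is
            # make_fx + partition_fn + aot_function over the backward module.
            def bw(x_, g_):
                return g_ * 3 * x_ ** 2
            LazyBackwardFn._compiled_bw = bw
        return LazyBackwardFn._compiled_bw(x, grad_out)

x = torch.tensor(2.0, requires_grad=True)
y = LazyBackwardFn.apply(x)
y.backward()
print(x.grad)  # 3 * 2.0**2 = 12.0
```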
2 changes: 1 addition & 1 deletion functorch/_src/partitioners.py
@@ -153,7 +153,7 @@ def default_partition(
                 saved_values.append(user)
         else:
             saved_values.append(node)
-    saved_values = list(set(saved_values))
+    saved_values = list(saved_values)

     return _extract_fwd_bwd_modules(joint_module, saved_values)

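The one-line partitioner change drops the `set()` around `saved_values`, which (presumably) keeps the saved values in the order they were collected; `set()` iteration order is not deterministic, and the positional pairing between the tensors saved in `forward` and the inputs of the separately compiled backward depends on that order. If duplicates ever needed to be removed again, an order-preserving dedup is possible; the snippet below is a hypothetical alternative, not something this commit does.

```python
# Hypothetical order-preserving dedup (not part of the commit): dict.fromkeys
# keeps first-seen order, unlike set().
def dedup_preserving_order(values):
    return list(dict.fromkeys(values))

print(dedup_preserving_order(["c", "a", "c", "b", "a"]))  # ['c', 'a', 'b']
```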
50 changes: 44 additions & 6 deletions test/test_pythonkey.py
@@ -246,25 +246,57 @@ def f(args, kwargs):

 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    # print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads

+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        # print("entering full_reduce: ", type(outs))
+        for out in outs:
+            res = res + out.sum()
+        return res
+    print("diff_outs, diff_inps: ", diff_outs, diff_inps)
+    grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True)
+    # print("grad call with: ", full_reduce(diff_outs), diff_inps)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    # print("grad grad call with: ", grads, full_reduce(diff_grads), diff_inps)
+    grad_grads = torch.autograd.grad(diff_grads, diff_inps)
+    return outs, grads, grad_grads

 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
         if isinstance(f, nn.Module):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        self.assertEqual(ref_grad_grad, test_grad_grad)

     def test_single_output(self):
         def f(a, b):
@@ -284,6 +316,12 @@ def f(a, b):
         inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
         self.verify_aot_autograd(f, inp)

+    def test_cube(self):
+        def f(a):
+            return a ** 3
+        inp = [torch.tensor(2.3, requires_grad=True)]
+        self.verify_aot_autograd(f, inp)
+
     def test_no_grad_input_output(self):
         def f(a, b):
             return a.cos(), b.cos(), a * b
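The test changes switch `verify_aot_autograd` to compare second-order gradients as well, and `test_cube` adds a function with a non-trivial second derivative. As an eager-mode reference (no `aot_function` involved), the quantities the new helper compares for `f(a) = a ** 3` at `a = 2.3` work out as below.

```python
# Eager-mode hand-check of the first- and second-order gradients that the new
# grad-grad test compares, for f(a) = a ** 3.
import torch

a = torch.tensor(2.3, requires_grad=True)
out = a ** 3

# First derivative, kept differentiable so it can be differentiated again.
(grad,) = torch.autograd.grad(out, a, create_graph=True)

# Second derivative.
(grad_grad,) = torch.autograd.grad(grad, a)

print(grad)       # 3 * 2.3**2 ≈ 15.87
print(grad_grad)  # 6 * 2.3 = 13.8
```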
