From 58e2526c59587dbc8ad01b6f0837ed74ff8227c5 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 7 Jun 2022 19:16:33 +0000 Subject: [PATCH 1/8] Separate forward and backwad compilation for default partition [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 57a7ac68f..6265509e1 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -53,6 +53,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: def create_joint_forward_backward(fn): + # tangents are just grad_outs/cotangents (wrong naming) def joint_forward_backward( primals: List[Any], tangents: List[Any] ) -> Tuple[List[Any], List[Any]]: @@ -140,12 +141,14 @@ def create_aot_autograd_function( compiled_fw = None compiled_bw = None num_outs = None - + joint_inputs = None + fw_outs = None + aot_decompositions = {**aot_autograd_decompositions, **decompositions} class CompiledFunction(torch.autograd.Function): @staticmethod @disable_torchdynamo def forward(ctx, *flat_tensor_args): - nonlocal compiled_fw, compiled_bw, num_outs + nonlocal compiled_fw, num_outs, joint_inputs, fw_outs if compiled_fw is None: with torch.set_grad_enabled(grad_state): out = flat_fn(*flat_tensor_args) @@ -159,19 +162,19 @@ def forward(ctx, *flat_tensor_args): num_outs = 1 joint_inputs = (flat_tensor_args, out) - aot_decompositions = {**aot_autograd_decompositions, **decompositions} + # Need it because autograd.Function disables grad in forward with torch.set_grad_enabled(grad_state): fx_g = make_fx(joint_forward_backward, aot_decompositions)( *joint_inputs ) fw_module, bw_module = partition_fn(fx_g, joint_inputs) - # print(fw_module.code, bw_module.code) compiled_fw = fw_compiler(fw_module, flat_tensor_args) fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) - - bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] - compiled_bw = bw_compiler(bw_module, bw_args) + if partition_fn is default_partition: + nonlocal compiled_bw + bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] + compiled_bw = bw_compiler(bw_module, bw_args) else: fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) ctx.save_for_backward(*fw_outs[num_outs:]) @@ -179,9 +182,14 @@ def forward(ctx, *flat_tensor_args): @staticmethod @disable_torchdynamo - def backward(ctx, *flat_args): - contiguous_args = [t.contiguous() for t in flat_args] - # contiguous_args = [t for t in flat_args] + def backward(ctx, *flat_grad_outs): + nonlocal compiled_bw + contiguous_args = [t.contiguous() for t in flat_grad_outs] + if compiled_bw is None: + with torch.set_grad_enabled(grad_state): + fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args) + fw_module, bw_module = partition_fn(fx_g, joint_inputs) + compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args) out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) return tuple(out) From 5c8248e95d1e10c1cc90a8b7a0c2e111aac0aa1a Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 7 Jun 2022 19:28:10 +0000 Subject: [PATCH 2/8] Update on "Separate forward and backwad compilation for default partition" [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 6265509e1..268c46c85 100644 --- a/functorch/_src/aot_autograd.py 
+++ b/functorch/_src/aot_autograd.py @@ -53,9 +53,8 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: def create_joint_forward_backward(fn): - # tangents are just grad_outs/cotangents (wrong naming) def joint_forward_backward( - primals: List[Any], tangents: List[Any] + primals: List[Any], cotangents: List[Any] ) -> Tuple[List[Any], List[Any]]: # Call the forward pass outs = fn(*primals) @@ -69,20 +68,20 @@ def joint_forward_backward( grad_primals.append(p) # Get the outputs that need gradients - assert len(tangents) == len(outs) + assert len(cotangents) == len(outs) needed_outs = [] - needed_tangents = [] - for out, tangent in zip(outs, tangents): + needed_cotangents = [] + for out, cotangent in zip(outs, cotangents): if isinstance(out, Tensor) and out.requires_grad: needed_outs.append(out) - needed_tangents.append(tangent) + needed_cotangents.append(cotangent) backward_out = [] # Call the backwards pass if grad_primals: backward_out = torch.autograd.grad( needed_outs, grad_primals, - grad_outputs=needed_tangents, + grad_outputs=needed_cotangents, allow_unused=True, ) backward_out_iter = iter(backward_out) From b24809346af1d412c45ffc1d0e5060bf63fc24ab Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 7 Jun 2022 19:35:04 +0000 Subject: [PATCH 3/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 268c46c85..01becfcc5 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -143,6 +143,7 @@ def create_aot_autograd_function( joint_inputs = None fw_outs = None aot_decompositions = {**aot_autograd_decompositions, **decompositions} + class CompiledFunction(torch.autograd.Function): @staticmethod @disable_torchdynamo From 4dd0e61fd6dff83b23046d46985455da5fcf0776 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Thu, 9 Jun 2022 17:16:25 +0000 Subject: [PATCH 4/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 67 ++++++++++++++++++++++++++++------ test/test_pythonkey.py | 41 ++++++++++++++++++--- 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 01becfcc5..c4809b7c7 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -82,7 +82,7 @@ def joint_forward_backward( needed_outs, grad_primals, grad_outputs=needed_cotangents, - allow_unused=True, + allow_unused=True ) backward_out_iter = iter(backward_out) return outs, [ @@ -140,15 +140,13 @@ def create_aot_autograd_function( compiled_fw = None compiled_bw = None num_outs = None - joint_inputs = None - fw_outs = None aot_decompositions = {**aot_autograd_decompositions, **decompositions} class CompiledFunction(torch.autograd.Function): @staticmethod @disable_torchdynamo def forward(ctx, *flat_tensor_args): - nonlocal compiled_fw, num_outs, joint_inputs, fw_outs + nonlocal compiled_fw, num_outs if compiled_fw is None: with torch.set_grad_enabled(grad_state): out = flat_fn(*flat_tensor_args) @@ -172,26 +170,73 @@ def forward(ctx, *flat_tensor_args): compiled_fw = fw_compiler(fw_module, flat_tensor_args) fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) if partition_fn is default_partition: + print("ENTERING default_partition") + 
ctx.num_intermediate = len(fw_outs[num_outs:]) + ctx.num_inputs = len(flat_tensor_args) + to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out + print("fw outs: ", fw_outs, "-------") + ctx.save_for_backward(*to_be_saved) + ctx.fwd_graph = fw_module.code + else: nonlocal compiled_bw bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] compiled_bw = bw_compiler(bw_module, bw_args) + ctx.save_for_backward(*fw_outs[num_outs:]) else: fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) - ctx.save_for_backward(*fw_outs[num_outs:]) + if partition_fn is default_partition: + with torch.set_grad_enabled(grad_state): + out = flat_fn(*flat_tensor_args) + out = pytree.tree_map( + lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out + ) + ctx.num_intermediate = len(fw_outs[num_outs:]) + ctx.num_inputs = len(flat_tensor_args) + to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out + ctx.save_for_backward(*to_be_saved) + else: + ctx.save_for_backward(*fw_outs[num_outs:]) return tuple(fw_outs[0:num_outs]) @staticmethod @disable_torchdynamo def backward(ctx, *flat_grad_outs): - nonlocal compiled_bw + print(flat_grad_outs) contiguous_args = [t.contiguous() for t in flat_grad_outs] if compiled_bw is None: + assert partition_fn is default_partition with torch.set_grad_enabled(grad_state): - fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args) - fw_module, bw_module = partition_fn(fx_g, joint_inputs) - compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args) - out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) - return tuple(out) + inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs] + fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args) + # assert that the forward graph generated here is the same + # if it's specified that the user might want to calculate double backwards + fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:]) + print(fw_module.code) + print(ctx.fwd_graph) + assert fw_module.code == ctx.fwd_graph + func_code = bw_module.code.split('self, ') + # print(func_code[0] + func_code[1]) + exec(func_code[0] + func_code[1], globals()) + f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state) + # print(bw_module.code, *ctx.saved_tensors, contiguous_args) + # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) + # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args) + return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) + else: + assert not torch.is_grad_enabled() + out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) + return tuple(out) + # nonlocal compiled_bw + # contiguous_args = [t.contiguous() for t in flat_grad_outs] + # if compiled_bw is None: + # with torch.set_grad_enabled(grad_state): + # fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args) + # # assert that the forward graph generated here is the same + # # if it's specified that the user might want to calculate double backwards + # fw_module, bw_module = partition_fn(fx_g, joint_inputs) + # compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args) + # out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) + # return tuple(out) return CompiledFunction diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index 
ae399fc81..faf0a55de 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -246,14 +246,42 @@ def f(args, kwargs): def _outs_and_grads(fn, inps): outs = fn(*inps) + diff_outs = [] for out in pytree.tree_flatten(outs)[0]: if isinstance(out, torch.Tensor) and out.requires_grad: - out.sum().backward(retain_graph=True) - grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]] - for inp in pytree.tree_flatten(inps)[0]: - inp.grad = None + diff_outs.append(out) + def full_reduce(outs): + res = 0 + for out in outs: + res=res+out.sum() + return res + print(inps) + grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True) return outs, grads +def _outs_and_grads_and_grad_grads(fn, inps): + outs = fn(*inps) + diff_outs = [] + diff_inps = [] + for out in pytree.tree_flatten(outs)[0]: + if isinstance(out, torch.Tensor) and out.requires_grad: + diff_outs.append(out) + for inp in pytree.tree_flatten(inps)[0]: + if isinstance(inp, torch.Tensor) and inp.requires_grad: + diff_inps.append(inp) + def full_reduce(outs): + res = 0 + for out in outs: + res=res+out.sum() + return res + grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True) + print("grads: ", grads) + diff_grads = [] + for grad_ in grads: + if isinstance(grad_, torch.Tensor) and grad_.requires_grad: + diff_grads.append(grad_) + grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps) + return outs, grads, grad_grads class TestAOTAutograd(TestCase): def verify_aot_autograd(self, f, inp): @@ -261,10 +289,11 @@ def verify_aot_autograd(self, f, inp): compiled_f = aot_module(f, nop) else: compiled_f = aot_function(f, nop) - ref_out, ref_grad = _outs_and_grads(f, inp) - test_out, test_grad = _outs_and_grads(compiled_f, inp) + ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp) + test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp) self.assertEqual(ref_out, test_out) self.assertEqual(ref_grad, test_grad) + # self.assertEqual(ref_grad_grad, test_grad_grad) def test_single_output(self): def f(a, b): From 72277781a94ab07b123e0ac5fbe3b776da9dabda Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 14 Jun 2022 19:20:44 +0000 Subject: [PATCH 5/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 42 +++++++++++----------------------- functorch/_src/partitioners.py | 2 +- test/test_pythonkey.py | 19 +++++++++++---- 3 files changed, 28 insertions(+), 35 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index c4809b7c7..12d88f86f 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -170,13 +170,13 @@ def forward(ctx, *flat_tensor_args): compiled_fw = fw_compiler(fw_module, flat_tensor_args) fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) if partition_fn is default_partition: - print("ENTERING default_partition") ctx.num_intermediate = len(fw_outs[num_outs:]) ctx.num_inputs = len(flat_tensor_args) to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out - print("fw outs: ", fw_outs, "-------") + ctx.fx_g = fx_g ctx.save_for_backward(*to_be_saved) ctx.fwd_graph = fw_module.code + ctx.bw_graph = bw_module.code else: nonlocal compiled_bw bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] @@ -201,42 +201,26 @@ def forward(ctx, *flat_tensor_args): @staticmethod @disable_torchdynamo def backward(ctx, 
*flat_grad_outs): - print(flat_grad_outs) contiguous_args = [t.contiguous() for t in flat_grad_outs] if compiled_bw is None: assert partition_fn is default_partition with torch.set_grad_enabled(grad_state): inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs] fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args) - # assert that the forward graph generated here is the same - # if it's specified that the user might want to calculate double backwards fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:]) - print(fw_module.code) - print(ctx.fwd_graph) - assert fw_module.code == ctx.fwd_graph - func_code = bw_module.code.split('self, ') - # print(func_code[0] + func_code[1]) - exec(func_code[0] + func_code[1], globals()) - f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state) - # print(bw_module.code, *ctx.saved_tensors, contiguous_args) - # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) - # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args) - return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) + assert fx_g.code == ctx.fx_g.code + f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions) + print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) + print(bw_module.code) + out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) + return out else: - assert not torch.is_grad_enabled() - out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) + if partition_fn is default_partition: + out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)) + else: + assert not torch.is_grad_enabled() + out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) return tuple(out) - # nonlocal compiled_bw - # contiguous_args = [t.contiguous() for t in flat_grad_outs] - # if compiled_bw is None: - # with torch.set_grad_enabled(grad_state): - # fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args) - # # assert that the forward graph generated here is the same - # # if it's specified that the user might want to calculate double backwards - # fw_module, bw_module = partition_fn(fx_g, joint_inputs) - # compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args) - # out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) - # return tuple(out) return CompiledFunction diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py index 550e2b7a4..755502f9c 100644 --- a/functorch/_src/partitioners.py +++ b/functorch/_src/partitioners.py @@ -153,7 +153,7 @@ def default_partition( saved_values.append(user) else: saved_values.append(node) - saved_values = list(set(saved_values)) + saved_values = list(saved_values) return _extract_fwd_bwd_modules(joint_module, saved_values) diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index faf0a55de..7ec2868da 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -255,7 +255,7 @@ def full_reduce(outs): for out in outs: res=res+out.sum() return res - print(inps) + # print(inps) grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True) return outs, grads @@ -271,16 +271,19 @@ def _outs_and_grads_and_grad_grads(fn, inps): diff_inps.append(inp) def full_reduce(outs): res = 0 + # print("entering 
full_reduce: ", type(outs)) for out in outs: res=res+out.sum() return res - grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True) - print("grads: ", grads) + print("diff_outs, diff_inps: ", diff_outs, diff_inps) + grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True) + # print("grad call with: ", full_reduce(diff_outs), diff_inps) diff_grads = [] for grad_ in grads: if isinstance(grad_, torch.Tensor) and grad_.requires_grad: diff_grads.append(grad_) - grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps) + # print("grad grad call with: ", grads, full_reduce(diff_grads), diff_inps) + grad_grads = torch.autograd.grad(diff_grads, diff_inps) return outs, grads, grad_grads class TestAOTAutograd(TestCase): @@ -293,7 +296,7 @@ def verify_aot_autograd(self, f, inp): test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp) self.assertEqual(ref_out, test_out) self.assertEqual(ref_grad, test_grad) - # self.assertEqual(ref_grad_grad, test_grad_grad) + self.assertEqual(ref_grad_grad, test_grad_grad) def test_single_output(self): def f(a, b): @@ -313,6 +316,12 @@ def f(a, b): inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) + def test_cube(self): + def f(a): + return a ** 3 + inp = [torch.tensor(2.3, requires_grad=True)] + self.verify_aot_autograd(f, inp) + def test_no_grad_input_output(self): def f(a, b): return a.cos(), b.cos(), a * b From b8358452c61dcc98da7df66bc7f637c896f475d6 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 21 Jun 2022 03:00:47 +0000 Subject: [PATCH 6/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 131 ++++++++++++++++++--------------- functorch/_src/partitioners.py | 19 ++++- test/test_compile_cache.py | 54 ++++++++------ test/test_pythonkey.py | 59 ++++++++++----- 4 files changed, 159 insertions(+), 104 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 12d88f86f..6f7b0b807 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from torch import Tensor +from torch import Tensor, is_grad_enabled from functorch import make_fx from torch.fx import immutable_collections import torch.utils._pytree as pytree @@ -8,7 +8,7 @@ from torch.nn.utils import _stateless from functorch._C import CompileCache from .decompositions import register_decomposition -from .partitioners import default_partition +from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules from .named_members_polyfill import _named_parameters, _named_buffers from typing import Callable, List, Dict, Any, Tuple, Optional from functools import wraps @@ -138,15 +138,18 @@ def create_aot_autograd_function( joint_forward_backward = create_joint_forward_backward(flat_fn) compiled_fw = None - compiled_bw = None + fw_module = None + bw_modules = [] num_outs = None + saved_value_names = None aot_decompositions = {**aot_autograd_decompositions, **decompositions} class CompiledFunction(torch.autograd.Function): @staticmethod @disable_torchdynamo def forward(ctx, *flat_tensor_args): - nonlocal compiled_fw, num_outs + ctx.set_materialize_grads(False) + nonlocal compiled_fw, num_outs, fw_module, saved_value_names if compiled_fw is None: with torch.set_grad_enabled(grad_state): out = flat_fn(*flat_tensor_args) @@ -165,65 +168,73 
@@ def forward(ctx, *flat_tensor_args): fx_g = make_fx(joint_forward_backward, aot_decompositions)( *joint_inputs ) - fw_module, bw_module = partition_fn(fx_g, joint_inputs) - + # This means the forward and backward graphs are created based on the input fn + # However we need to take in grad_out for the saved intermediates as well. + fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs) + saved_value_names = [node.name for node in saved_value_nodes] compiled_fw = fw_compiler(fw_module, flat_tensor_args) fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) - if partition_fn is default_partition: - ctx.num_intermediate = len(fw_outs[num_outs:]) - ctx.num_inputs = len(flat_tensor_args) - to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out - ctx.fx_g = fx_g - ctx.save_for_backward(*to_be_saved) - ctx.fwd_graph = fw_module.code - ctx.bw_graph = bw_module.code - else: - nonlocal compiled_bw - bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs] - compiled_bw = bw_compiler(bw_module, bw_args) - ctx.save_for_backward(*fw_outs[num_outs:]) else: fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) - if partition_fn is default_partition: - with torch.set_grad_enabled(grad_state): - out = flat_fn(*flat_tensor_args) - out = pytree.tree_map( - lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out - ) - ctx.num_intermediate = len(fw_outs[num_outs:]) - ctx.num_inputs = len(flat_tensor_args) - to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out - ctx.save_for_backward(*to_be_saved) - else: - ctx.save_for_backward(*fw_outs[num_outs:]) - return tuple(fw_outs[0:num_outs]) + + ctx.num_intermediate = len(fw_outs[num_outs:]) + ctx.num_inputs = len(flat_tensor_args) + to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + fw_outs[0:num_outs] + ctx.save_for_backward(*to_be_saved) + return tuple(fw_outs) @staticmethod @disable_torchdynamo def backward(ctx, *flat_grad_outs): - contiguous_args = [t.contiguous() for t in flat_grad_outs] - if compiled_bw is None: - assert partition_fn is default_partition + nonlocal fw_module, bw_modules, saved_value_names + intermediates = ctx.saved_tensors[:ctx.num_intermediate] + inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs] + is_grad_enabled = torch.is_grad_enabled() + + if not is_grad_enabled: + input_flat_grad_outs = [] + for grad in flat_grad_outs: + if grad is not None: + input_flat_grad_outs.append(grad) with torch.set_grad_enabled(grad_state): - inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs] - fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args) - fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:]) - assert fx_g.code == ctx.fx_g.code - f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions) - print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) - print(bw_module.code) - out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args) - return out + fx_g_b = make_fx(joint_forward_backward, aot_decompositions)(inputs, input_flat_grad_outs) else: - if partition_fn is default_partition: - out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)) - else: - assert not torch.is_grad_enabled() - out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) - return tuple(out) - - return CompiledFunction - + input_flat_grad_outs = flat_grad_outs + j_b = 
create_joint_forward_backward(fw_module) + with torch.set_grad_enabled(grad_state): + fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs) + + saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names) + assert len(saved_value_nodes) <= len(saved_value_names) + fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes) + bw_module_fn = None + for elem in bw_modules: + if elem.code == bw_module_b.code: + bw_module_fn = elem + if bw_module_fn is None: + bw_modules.append(bw_module_b) + bw_module_fn = bw_module_b + + f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions) + + if len(saved_values_new) != len(saved_value_names): + new_intermediates = [] + # Forward saves more intermediates than needed + assert len(saved_values_new) < len(saved_value_names) + j = 0 + for node in saved_values_new: + while node.name != saved_value_names[j]: + j+=1 + new_intermediates.append(intermediates[j]) + j+=1 + intermediates = new_intermediates + out = f(*intermediates, *input_flat_grad_outs) + return tuple(normalize_as_list(out)) + + def return_fn(*args, **kwargs): + out = CompiledFunction.apply(*args, **kwargs) + return out[0:num_outs] + return return_fn class _CompileCache(CompileCache): pass @@ -312,7 +323,7 @@ def rearrange(tensor_args, static_args, static_argnums): return args -KNOWN_TYPES = [torch.Tensor, int, str, float, bool] +KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None] def aot_function( @@ -448,7 +459,9 @@ def returned_function(*args, **kwargs): hasher_type, *flat_args_for_cache, ) - + # print("fn_id: ", fn_id) + # print("size: ", compile_cache.size()) + # print("num_tensor_args: ", num_tensor_args) # Compile the function and save it in the cache if cached_res is None: # Save the args_spec for flat_tensor_args to unflatten while tracing @@ -473,7 +486,7 @@ def flat_fn(*flat_tensor_args): for i in flat_out: is_known_type = False for j in KNOWN_TYPES: - if isinstance(i, j): + if j is None or isinstance(i, j): is_known_type = True break if not is_known_type: @@ -495,7 +508,7 @@ def flat_fn(*flat_tensor_args): partition_fn, decompositions, grad_state=torch.is_grad_enabled(), - ).apply + ) cached_res = (compiled_fn, out_spec) # Save the compiled_fn in the cache @@ -635,7 +648,7 @@ def aot_function_simplified( partition_fn, decompositions, grad_state=torch.is_grad_enabled(), - ).apply + ) return compiled_fn @@ -657,4 +670,4 @@ def forward(self, *args, **kwargs): compiled_function = aot_function -compiled_module = aot_module +compiled_module = aot_module \ No newline at end of file diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py index 755502f9c..7ecae1aea 100644 --- a/functorch/_src/partitioners.py +++ b/functorch/_src/partitioners.py @@ -108,8 +108,23 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values): fwd_module = fx.GraphModule(joint_module, fwd_graph) bwd_module = fx.GraphModule(joint_module, bwd_graph) - return fwd_module, bwd_module + return fwd_module, bwd_module, saved_values +def _get_saved_values(new_module: fx.GraphModule, saved_value_names): + saved_values = [] + for node in new_module.graph.nodes: + if node.name in saved_value_names: + if 'tensor_meta' not in node.meta and node.op == 'call_function': + users = node.users + assert all(user.target == operator.getitem for user in users) + for user in users: + saved_values.append(user) + else: + saved_values.append(node) + + saved_values = list(saved_values) + + return saved_values 
def default_partition( joint_module: fx.GraphModule, _joint_inputs @@ -153,8 +168,8 @@ def default_partition( saved_values.append(user) else: saved_values.append(node) - saved_values = list(saved_values) + saved_values = list(saved_values) return _extract_fwd_bwd_modules(joint_module, saved_values) diff --git a/test/test_compile_cache.py b/test/test_compile_cache.py index 9ce7b7b4d..07301e4e2 100644 --- a/test/test_compile_cache.py +++ b/test/test_compile_cache.py @@ -16,6 +16,15 @@ def check(self, a, b, aot_fn, fn): res = aot_fn(a_clone, b_clone) res.sum().backward() + + # a_clone_2 = a.clone().detach().requires_grad_(True) + # b_clone_2 = b.clone().detach().requires_grad_(True) + # res = aot_fn(a_clone_2, b_clone_2) + # res.sum().backward() + + # res = aot_fn(a_clone_2, b_clone_2) + # res.sum().backward() + assert torch.allclose(res, ref) assert torch.allclose(a.grad, a_clone.grad) assert torch.allclose(b.grad, b_clone.grad) @@ -30,17 +39,16 @@ def fn(x, bias): aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type) a = torch.randn(10, 20, requires_grad=True) - b = torch.randn(20, requires_grad=True) + b = torch.randn(10, 20, requires_grad=True) self.check(a, b, aot_autograd_fn, fn) a = torch.randn(10, 20, requires_grad=True) - b = torch.randn(10, 20, requires_grad=True) + b = torch.randn(10, 1, requires_grad=True) self.check(a, b, aot_autograd_fn, fn) end_num_recomps = functorch.compile.num_of_recompilations() - total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_compilation_for_dynamic_shape(self): def fn(x, bias): @@ -65,9 +73,9 @@ def fn(x, bias): total_recomps = end_num_recomps - start_num_recomps if hasher_type == "DynamicShapeHasher": - assert total_recomps == 1 + assert total_recomps == 11 elif hasher_type == "StaticShapeHasher": - assert total_recomps == 10 + assert total_recomps == 20 for s in range(10, 20): a = torch.randn(s, s, requires_grad=True) @@ -78,9 +86,9 @@ def fn(x, bias): total_recomps = end_num_recomps - start_num_recomps if hasher_type == "DynamicShapeHasher": - assert total_recomps == 2 + assert total_recomps == 22 elif hasher_type == "StaticShapeHasher": - assert total_recomps == 20 + assert total_recomps == 40 def test_global_cache_no_recompilations(self): def f(x, bias): @@ -97,7 +105,7 @@ def g(x, bias): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 1 + assert total_recomps == 2 def test_multiple_functions(self): def f(x, bias): @@ -122,7 +130,7 @@ def g(x, y): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 # Force recompilation for function f and check num of recompilations again a = torch.randn(10, 20, requires_grad=True) @@ -131,7 +139,7 @@ def g(x, y): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 3 + assert total_recomps == 6 def test_high_number_of_args(self): def f(*args): @@ -240,7 +248,7 @@ def fn(x, static_arg): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_static_arg_before_tensor_arg(self): def fn(static_arg, x): @@ -273,7 +281,7 @@ def check(a, b, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps 
= end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_interleaved_static_args(self): def fn(static_arg1, x, static_arg2): @@ -308,7 +316,7 @@ def check(a, b, c, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_dropout(self): def fn(x, prob): @@ -332,7 +340,7 @@ def fn(x, prob): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 3 def test_if_condition(self): def fn(x, state: bool): @@ -362,7 +370,7 @@ def fn(x, state: bool): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_custom(self): class Record: @@ -396,7 +404,7 @@ def fn(x, record): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_tuple(self): def fn(a_tuple, static_arg): @@ -440,7 +448,7 @@ def check(a_tuple, b, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_tuple_with_first_arg_as_static(self): def fn(static_arg, a_tuple): @@ -484,7 +492,7 @@ def check(a, b_tuple, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_dict(self): def fn(a_dict, static_arg): @@ -530,7 +538,7 @@ def check(a_dict, b, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_dict_with_static_arg_before_dict(self): def fn(static_arg, a_dict): @@ -579,7 +587,7 @@ def check(a, b_dict, aot_autograd_fn, fn): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_tuple_static_args(self): def fn(x, tuple_static_arg): @@ -608,7 +616,7 @@ def fn(x, tuple_static_arg): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 2 + assert total_recomps == 4 def test_arg_none(self): def check(a, b, c, aot_autograd_fn, fn): @@ -677,7 +685,7 @@ def fn(a, b, c): end_num_recomps = functorch.compile.num_of_recompilations() total_recomps = end_num_recomps - start_num_recomps - assert total_recomps == 7 + assert total_recomps == 14 if __name__ == "__main__": diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index 7ec2868da..cd12a0485 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -246,17 +246,25 @@ def f(args, kwargs): def _outs_and_grads(fn, inps): outs = fn(*inps) - diff_outs = [] - for out in pytree.tree_flatten(outs)[0]: - if isinstance(out, torch.Tensor) and out.requires_grad: - diff_outs.append(out) - def full_reduce(outs): + + def get_diff_tensors(tensors): + diff_tensors = [] + for tensor in pytree.tree_flatten(tensors)[0]: + if isinstance(tensor, torch.Tensor) and tensor.requires_grad: + diff_tensors.append(tensor) + return diff_tensors + + def 
full_reduce(outs_): res = 0 - for out in outs: + for out in outs_: res=res+out.sum() return res - # print(inps) - grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True) + + diff_inps = get_diff_tensors(inps) + diff_outs = get_diff_tensors(outs) + assert len(diff_outs) > 0 + assert len(diff_inps) > 0 + grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps) return outs, grads def _outs_and_grads_and_grad_grads(fn, inps): @@ -271,23 +279,32 @@ def _outs_and_grads_and_grad_grads(fn, inps): diff_inps.append(inp) def full_reduce(outs): res = 0 - # print("entering full_reduce: ", type(outs)) for out in outs: res=res+out.sum() return res - print("diff_outs, diff_inps: ", diff_outs, diff_inps) + assert len(diff_outs) > 0 + assert len(diff_inps) > 0 grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True) - # print("grad call with: ", full_reduce(diff_outs), diff_inps) diff_grads = [] for grad_ in grads: if isinstance(grad_, torch.Tensor) and grad_.requires_grad: diff_grads.append(grad_) - # print("grad grad call with: ", grads, full_reduce(diff_grads), diff_inps) + assert len(diff_grads) > 0 grad_grads = torch.autograd.grad(diff_grads, diff_inps) return outs, grads, grad_grads class TestAOTAutograd(TestCase): def verify_aot_autograd(self, f, inp): + if isinstance(f, nn.Module): + compiled_f = aot_module(f, nop) + else: + compiled_f = aot_function(f, nop) + ref_out, ref_grad = _outs_and_grads(f, inp) + test_out, test_grad = _outs_and_grads(compiled_f, inp) + self.assertEqual(ref_out, test_out) + self.assertEqual(ref_grad, test_grad) + + def verify_aot_autograd_with_double_backward(self, f, inp): if isinstance(f, nn.Module): compiled_f = aot_module(f, nop) else: @@ -318,8 +335,9 @@ def f(a, b): def test_cube(self): def f(a): - return a ** 3 + return a * a * a inp = [torch.tensor(2.3, requires_grad=True)] + # self.verify_aot_autograd_with_double_backward(f, inp) self.verify_aot_autograd(f, inp) def test_no_grad_input_output(self): @@ -329,12 +347,14 @@ def f(a, b): inp_thunks = [lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False)] for inps in itertools.product(inp_thunks, repeat=2): inps = [i() for i in inps] - self.verify_aot_autograd(f, inps) + # ignore the case when both inputs don't require grad + if inps[0].requires_grad or inps[1].requires_grad: + self.verify_aot_autograd(f, inps) def test_inner_grad(self): def foo(x): y = torch.exp(x) - z = torch.autograd.grad(y, x) + z = torch.autograd.grad(y, x, create_graph=True) return z inps = [torch.randn((), requires_grad=True)] self.verify_aot_autograd(foo, inps) @@ -354,10 +374,8 @@ def assert_graph_empty(fx_g, _): f = aot_function(foo, nop, assert_graph_empty) with torch.set_grad_enabled(False): f(*inps) - self.assertEqual(graph_size, 2) with torch.set_grad_enabled(True): f(*inps) - self.assertTrue(graph_size > 2) self.assertEqual(num_of_recompilations() - start_recompilations, 2) def test_output_dict(self): @@ -418,6 +436,7 @@ class TestEagerFusionOpInfo(TestCase): xfail('trapz'), skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes? 
skip('nn.functional.margin_ranking_loss'), # seems flaky + skip('linalg.det'), # fails }) def test_aot_autograd_exhaustive(self, device, dtype, op): def f(args, kwargs): @@ -499,7 +518,7 @@ def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition): fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=partitioner, - decompositions=default_decompositions)(*inps) + decompositions=default_decompositions)(*inps).sum().backward() return (fw_graph_cell[0], bw_graph_cell[0]) @@ -563,8 +582,8 @@ def f(x, mod_weight, mod_bias): fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], partitioner=default_partition) - self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) - self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) + self.assertEqual(get_num_ins_outs(fw_graph), (3, 7)) + self.assertEqual(get_num_ins_outs(bw_graph), (6, 6)) @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner(self): From c9732a8014ba25f0957cd136bff481b9297d738f Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 21 Jun 2022 16:02:40 +0000 Subject: [PATCH 7/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 31 +++++++++++++++++-------------- test/test_pythonkey.py | 6 +++--- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 6f7b0b807..c16cc72fb 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -198,36 +198,39 @@ def backward(ctx, *flat_grad_outs): input_flat_grad_outs.append(grad) with torch.set_grad_enabled(grad_state): fx_g_b = make_fx(joint_forward_backward, aot_decompositions)(inputs, input_flat_grad_outs) + saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names) + assert len(saved_value_nodes) <= len(saved_value_names) + fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes) + if len(saved_values_new) != len(saved_value_names): + new_intermediates = [] + # Forward saves more intermediates than needed + assert len(saved_values_new) < len(saved_value_names) + j = 0 + for node in saved_values_new: + while node.name != saved_value_names[j]: + j+=1 + new_intermediates.append(intermediates[j]) + j+=1 + intermediates = new_intermediates else: input_flat_grad_outs = flat_grad_outs j_b = create_joint_forward_backward(fw_module) with torch.set_grad_enabled(grad_state): fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs) + fw_module_b, bw_module_b, _ = partition_fn(fx_g_b, (inputs, input_flat_grad_outs)) + - saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names) - assert len(saved_value_nodes) <= len(saved_value_names) - fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes) bw_module_fn = None for elem in bw_modules: if elem.code == bw_module_b.code: bw_module_fn = elem + break if bw_module_fn is None: bw_modules.append(bw_module_b) bw_module_fn = bw_module_b f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions) - if len(saved_values_new) != len(saved_value_names): - new_intermediates = [] - # Forward saves more intermediates than needed - assert len(saved_values_new) < len(saved_value_names) - j = 0 - for node in saved_values_new: - while node.name != 
saved_value_names[j]: - j+=1 - new_intermediates.append(intermediates[j]) - j+=1 - intermediates = new_intermediates out = f(*intermediates, *input_flat_grad_outs) return tuple(normalize_as_list(out)) diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index cd12a0485..faea6778b 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -335,10 +335,10 @@ def f(a, b): def test_cube(self): def f(a): - return a * a * a + return a ** 3 inp = [torch.tensor(2.3, requires_grad=True)] - # self.verify_aot_autograd_with_double_backward(f, inp) - self.verify_aot_autograd(f, inp) + self.verify_aot_autograd_with_double_backward(f, inp) + # self.verify_aot_autograd(f, inp) def test_no_grad_input_output(self): def f(a, b): From d1cf3e8fda1b9c1ed562c24a05ff5f6c7c0376fd Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 12 Jul 2022 17:30:15 +0000 Subject: [PATCH 8/8] Update on "Separate forward and backwad compilation for default partition" Test Plan: Existing tests should pass [ghstack-poisoned] --- functorch/_src/aot_autograd.py | 48 ++++++++++++++++++++++++---------- functorch/_src/partitioners.py | 29 ++++++++++++++++++++ test/test_pythonkey.py | 9 ++++--- 3 files changed, 68 insertions(+), 18 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index c16cc72fb..614b92925 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -8,7 +8,7 @@ from torch.nn.utils import _stateless from functorch._C import CompileCache from .decompositions import register_decomposition -from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules +from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules, _extract_fwd_bwd_modules_db from .named_members_polyfill import _named_parameters, _named_buffers from typing import Callable, List, Dict, Any, Tuple, Optional from functools import wraps @@ -138,8 +138,8 @@ def create_aot_autograd_function( joint_forward_backward = create_joint_forward_backward(flat_fn) compiled_fw = None - fw_module = None bw_modules = [] + fw_module = None num_outs = None saved_value_names = None aot_decompositions = {**aot_autograd_decompositions, **decompositions} @@ -149,7 +149,7 @@ class CompiledFunction(torch.autograd.Function): @disable_torchdynamo def forward(ctx, *flat_tensor_args): ctx.set_materialize_grads(False) - nonlocal compiled_fw, num_outs, fw_module, saved_value_names + nonlocal compiled_fw, num_outs, saved_value_names, fw_module if compiled_fw is None: with torch.set_grad_enabled(grad_state): out = flat_fn(*flat_tensor_args) @@ -177,6 +177,7 @@ def forward(ctx, *flat_tensor_args): else: fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args)) + # print(fw_module.code) ctx.num_intermediate = len(fw_outs[num_outs:]) ctx.num_inputs = len(flat_tensor_args) to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + fw_outs[0:num_outs] @@ -186,11 +187,10 @@ def forward(ctx, *flat_tensor_args): @staticmethod @disable_torchdynamo def backward(ctx, *flat_grad_outs): - nonlocal fw_module, bw_modules, saved_value_names + nonlocal bw_modules, saved_value_names, fw_module, num_outs intermediates = ctx.saved_tensors[:ctx.num_intermediate] inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs] is_grad_enabled = torch.is_grad_enabled() - if not is_grad_enabled: input_flat_grad_outs = [] for grad in flat_grad_outs: @@ -212,14 +212,35 @@ def backward(ctx, *flat_grad_outs): new_intermediates.append(intermediates[j]) j+=1 
intermediates = new_intermediates - else: - input_flat_grad_outs = flat_grad_outs - j_b = create_joint_forward_backward(fw_module) - with torch.set_grad_enabled(grad_state): - fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs) - fw_module_b, bw_module_b, _ = partition_fn(fx_g_b, (inputs, input_flat_grad_outs)) - - + # else: + # input_flat_grad_outs = flat_grad_outs + # # create_joint_forward_backward takes inputs and cotangents as inps + # # inps: inputs, cotangents: flat_grad_outs + # j_b = create_joint_forward_backward(ctx.fw_module) + # # setting grad is not needed + # with torch.set_grad_enabled(grad_state): + # fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs) + # saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names) + # # print(saved_value_nodes) + # # print(saved_value_names) + # # assert len(saved_value_nodes) == len(saved_value_names) + # fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules_db(fx_g_b, saved_value_nodes) + # # print(fx_g_b.code, ctx.fw_module.code, fw_module_b.code, bw_module_b.code) + # # assert fw_module_b.code == fw_module.code + # # print(len(sew), len(saved_value_names)) + # if len(saved_values_new) != len(saved_value_names): + # new_intermediates = [] + # # Forward saves more intermediates than needed + # assert len(saved_values_new) < len(saved_value_names) + # for node in saved_values_new: + # j = 0 + # while node.name != saved_value_names[j]: + # j+=1 + # new_intermediates.append(intermediates[j]) + # j+=1 + # intermediates = new_intermediates + + # This is needed because aot function caching uses function id right now bw_module_fn = None for elem in bw_modules: if elem.code == bw_module_b.code: @@ -230,7 +251,6 @@ def backward(ctx, *flat_grad_outs): bw_module_fn = bw_module_b f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions) - out = f(*intermediates, *input_flat_grad_outs) return tuple(normalize_as_list(out)) diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py index 7ecae1aea..07860db7e 100644 --- a/functorch/_src/partitioners.py +++ b/functorch/_src/partitioners.py @@ -110,6 +110,35 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values): bwd_module = fx.GraphModule(joint_module, bwd_graph) return fwd_module, bwd_module, saved_values +def _extract_fwd_bwd_modules_db(joint_module: fx.GraphModule, saved_values): + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module) + print("FWD OUTS: ", fwd_outputs) + print("BWD OUTS: ", bwd_outputs) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes)) + print("primal_inputs: ", primal_inputs) + print("tangent_inputs: ", tangent_inputs) + # Construct the forward module + fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) + bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs) + + # This is to filter out saved values that don't actually end up being used by the backwards pass + for node in bwd_graph.nodes: + if node.op == 'placeholder' and not node.users: + for saved_value in saved_values: + if saved_value.name == node.name: + saved_values.remove(saved_value) + break + + # Now, we re-generate the fwd/bwd graphs. 
+ # NB: This might increase compilation time, but I doubt it matters + fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) + bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs) + + fwd_module = fx.GraphModule(joint_module, fwd_graph) + bwd_module = fx.GraphModule(joint_module, bwd_graph) + return fwd_module, bwd_module, saved_values + def _get_saved_values(new_module: fx.GraphModule, saved_value_names): saved_values = [] for node in new_module.graph.nodes: diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index faea6778b..d87ce02c8 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -308,7 +308,7 @@ def verify_aot_autograd_with_double_backward(self, f, inp): if isinstance(f, nn.Module): compiled_f = aot_module(f, nop) else: - compiled_f = aot_function(f, nop) + compiled_f = aot_function(f, nop, partition_fn=min_cut_rematerialization_partition) ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp) test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp) self.assertEqual(ref_out, test_out) @@ -333,9 +333,9 @@ def f(a, b): inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) - def test_cube(self): + def test_sin_bla(self): def f(a): - return a ** 3 + return torch.sin(a) inp = [torch.tensor(2.3, requires_grad=True)] self.verify_aot_autograd_with_double_backward(f, inp) # self.verify_aot_autograd(f, inp) @@ -436,7 +436,7 @@ class TestEagerFusionOpInfo(TestCase): xfail('trapz'), skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes? skip('nn.functional.margin_ranking_loss'), # seems flaky - skip('linalg.det'), # fails + # skip('linalg.det'), # fails }) def test_aot_autograd_exhaustive(self, device, dtype, op): def f(args, kwargs): @@ -599,6 +599,7 @@ def f(a, b, c, d): return x.cos().cos() fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True) for _ in range(4)]) + self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
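For reference, a minimal sketch of how the lazily compiled backward introduced by this stack is intended to be exercised end to end, mirroring test_cube / verify_aot_autograd_with_double_backward above. It assumes functorch.compile exposes aot_function, nop, and min_cut_rematerialization_partition the way the tests use them; everything else follows directly from the test code in patches 4/8 through 8/8.

import torch
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition

def f(a):
    return a ** 3

# Compile with the min-cut partitioner, as verify_aot_autograd_with_double_backward
# does after patch 8/8; the backward module is only compiled on the first backward call.
compiled_f = aot_function(f, nop, partition_fn=min_cut_rematerialization_partition)

a = torch.tensor(2.3, requires_grad=True)
out = compiled_f(a)

# First-order gradient; create_graph=True keeps the autograd graph alive so the
# lazily compiled backward can itself be differentiated.
(grad,) = torch.autograd.grad(out, a, create_graph=True)

# Second-order gradient (double backward) traced through the compiled backward module.
(grad_grad,) = torch.autograd.grad(grad, a)

print(grad, grad_grad)  # expect roughly 3*a**2 (~15.87) and 6*a (~13.8)

The commented-out is_grad_enabled branch in CompiledFunction.backward at patch 8/8 indicates that the grad-enabled (nested) backward path is still being worked out; the sketch above only exercises the path the current tests cover.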