
Torchdynamo with aot_autograd_speedup_strategy has increased memory usage and long overhead on ResNet50 model #93751

Closed
xwang233 opened this issue May 3, 2022 · 12 comments

Comments

@xwang233
Collaborator

xwang233 commented May 3, 2022

Python script (please correct me if I'm using the torchdynamo API wrong):

dynamo-test.py

import contextlib
import traceback
import time

import torch
import torchvision
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

N_WARMUP = 100
N_BENCH = 100

def bench(batch_size, use_dynamo):
    model = torchvision.models.resnet50().cuda()
    x = torch.randn(batch_size, 3, 224, 224, dtype=torch.float, device='cuda')

    train_context = torchdynamo.optimize(aot_autograd_speedup_strategy) if use_dynamo is True else contextlib.nullcontext()

    torch.cuda.synchronize()
    t0 = time.time()

    with train_context:
        for _ in range(N_WARMUP):
            out = model(x)
            out.sum().backward()

        torch.cuda.synchronize()
        t1 = time.time()

        for _ in range(N_BENCH):
            out = model(x)
            out.sum().backward()

        torch.cuda.synchronize()
        t2 = time.time()

    print('Training img/s (larger better):', batch_size / ((t2 - t1) / N_BENCH))
    print('Total time incl. overhead (smaller better):', t2 - t0)
    print()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    parser.add_argument('--use_dynamo', action='store_true', default=False)
    args = parser.parse_args()
    print(args)

    bench(args.batch_size, args.use_dynamo)

Bash script:

#!/bin/bash
python dynamo-test.py -b 16
python dynamo-test.py -b 16 --use_dynamo
python dynamo-test.py -b 64
python dynamo-test.py -b 64 --use_dynamo
python dynamo-test.py -b 128
python dynamo-test.py -b 128 --use_dynamo

Tested with a V100 16GB on
f6bbecf,
pytorch/vision@104073c,
torchdynamo@0d59ce9,
pytorch/functorch@ac0fdf1,
CUDA 11.6 Update 1, cuDNN 8.3.3.40

Results:

00:00:15 Namespace(batch_size=16, use_dynamo=False)
00:00:15 Training img/s (larger better): 306.0273543197808
00:00:15 Total time incl. overhead (smaller better): 11.840148687362671
00:00:15 
00:00:41 Namespace(batch_size=16, use_dynamo=True)
00:00:41 Training img/s (larger better): 307.23167670348334
00:00:41 Total time incl. overhead (smaller better): 21.702965021133423
00:00:41 
00:01:25 Namespace(batch_size=64, use_dynamo=False)
00:01:25 Training img/s (larger better): 338.7971538235037
00:01:25 Total time incl. overhead (smaller better): 39.105363607406616
00:01:25 
00:02:18 Namespace(batch_size=64, use_dynamo=True)
00:02:18 Training img/s (larger better): 344.55961021952044
00:02:18 Total time incl. overhead (smaller better): 48.27671003341675
00:02:18 
00:03:37 Namespace(batch_size=128, use_dynamo=False)
00:03:37 Training img/s (larger better): 350.9537424835023
00:03:37 Total time incl. overhead (smaller better): 74.2812168598175
00:03:37 
00:03:43 ERROR FROM offset=100 filename /opt/pytorch/vision/torchvision/models/resnet.py 156 RuntimeError
00:03:43 ERROR FROM offset=66 filename /opt/pytorch/vision/torchvision/models/resnet.py 273 RuntimeError
00:03:43 ERROR FROM offset=6 filename /opt/pytorch/vision/torchvision/models/resnet.py 283 RuntimeError
00:03:43 ========== TorchDynamo Stack Trace ==========
00:03:43 Traceback (most recent call last):
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 170, in _convert_frame_assert
00:03:43     code = transform_code_object(frame.f_code, transform)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/bytecode_transformation.py", line 338, in transform_code_object
00:03:43     transformations(instructions, code_options)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 146, in transform
00:03:43     tracer.run()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43     and self.step()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43     getattr(self, inst.opname)(inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43     return inner_fn(self, inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43     self.call_function(fn, args, {})
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43     self.push(fn.call_function(self, args, kwargs))
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/functions.py", line 182, in call_function
00:03:43     return super().call_function(tx, args, kwargs)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/functions.py", line 64, in call_function
00:03:43     return tx.inline_user_function_return(
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 227, in inline_user_function_return
00:03:43     result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1204, in inline_call
00:03:43     return cls.inline_call_(parent, func, args, kwargs)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1246, in inline_call_
00:03:43     tracer.run()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43     and self.step()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43     getattr(self, inst.opname)(inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43     return inner_fn(self, inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43     self.call_function(fn, args, {})
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43     self.push(fn.call_function(self, args, kwargs))
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 158, in call_function
00:03:43     tx.call_function(
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43     self.push(fn.call_function(self, args, kwargs))
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 185, in call_function
00:03:43     return tx.inline_user_function_return(
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 227, in inline_user_function_return
00:03:43     result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1204, in inline_call
00:03:43     return cls.inline_call_(parent, func, args, kwargs)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 1246, in inline_call_
00:03:43     tracer.run()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 278, in run
00:03:43     and self.step()
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 256, in step
00:03:43     getattr(self, inst.opname)(inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 142, in wrapper
00:03:43     return inner_fn(self, inst)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 556, in CALL_FUNCTION
00:03:43     self.call_function(fn, args, {})
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43     self.push(fn.call_function(self, args, kwargs))
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 158, in call_function
00:03:43     tx.call_function(
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 195, in call_function
00:03:43     self.push(fn.call_function(self, args, kwargs))
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 172, in call_function
00:03:43     return variables.TensorVariable.create(
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/variables/tensor.py", line 94, in create
00:03:43     proxy.node.meta["example_value"] = clone_tensor(example_value)
00:03:43   File "/opt/conda/lib/python3.8/site-packages/torchdynamo/utils.py", line 151, in clone_tensor
00:03:43     y = x.clone().requires_grad_(x.requires_grad)
00:03:43 RuntimeError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 15.78 GiB total capacity; 14.54 GiB already allocated; 19.75 MiB free; 14.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
00:03:43 ========== Exception (above) while processing ==========
00:03:43   File "dynamo-test.py", line 50, in <module>
00:03:43     bench(args.batch_size, args.use_dynamo)
00:03:43   File "dynamo-test.py", line 24, in bench
00:03:43     out = model(x)
00:03:43   File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43     return forward_call(*input, **kwargs)
00:03:43   File "/opt/pytorch/vision/torchvision/models/resnet.py", line 282, in forward
00:03:43     def forward(self, x: Tensor) -> Tensor:
00:03:43 ========== End debug info ==========
00:03:43 Namespace(batch_size=128, use_dynamo=True)
00:03:43 Traceback (most recent call last):
00:03:43   File "dynamo-test.py", line 50, in <module>
00:03:43     bench(args.batch_size, args.use_dynamo)
00:03:43   File "dynamo-test.py", line 24, in bench
00:03:43     out = model(x)
00:03:43   File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43     return forward_call(*input, **kwargs)
00:03:43   File "/opt/pytorch/vision/torchvision/models/resnet.py", line 283, in forward
00:03:43     return self._forward_impl(x)
00:03:43   File "/opt/pytorch/vision/torchvision/models/resnet.py", line 266, in _forward_impl
00:03:43     x = self.conv1(x)
00:03:43   File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
00:03:43     return forward_call(*input, **kwargs)
00:03:43   File "/opt/pytorch/pytorch/torch/nn/modules/conv.py", line 447, in forward
00:03:43     return self._conv_forward(input, self.weight, self.bias)
00:03:43   File "/opt/pytorch/pytorch/torch/nn/modules/conv.py", line 443, in _conv_forward
00:03:43     return F.conv2d(input, weight, bias, self.stride,
00:03:43 RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 15.78 GiB total capacity; 14.52 GiB already allocated; 19.75 MiB free; 14.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

As shown above, torchdynamo with aot_autograd_speedup_strategy has higher memory usage and longer overhead on the ResNet50 model than eager mode; at batch size 128 it runs out of memory on the 16 GB V100, while eager mode completes.
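For reference, here is a small standalone snippet (my addition, not part of the original report) that quantifies the memory side with torch.cuda.max_memory_allocated(); the same two calls can be added inside bench() to compare the eager and dynamo runs directly.

import torch
import torchvision

# Measure peak GPU memory for one forward/backward step at the failing batch size.
model = torchvision.models.resnet50().cuda()
x = torch.randn(128, 3, 224, 224, device='cuda')

torch.cuda.reset_peak_memory_stats()
out = model(x)
out.sum().backward()
torch.cuda.synchronize()
print('Peak GPU memory (GiB):', torch.cuda.max_memory_allocated() / 2**30)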

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @csarofeen @ptrblck @jjsjann123 @kevinstephano @jansel

@jansel
Contributor

jansel commented May 3, 2022

cc @anijain2305 who was looking into a similar issue.

@anijain2305
Contributor

Looking into this.

I have seen that AOT Autograd can increase the memory footprint while tracing (maybe we are duplicating the memory somewhere; I have to check). But failing at a batch size of 256 sounds bad. Thanks for pointing it out.

The performance drop for working batch sizes is unexpected. I will update within a couple of days.

@Chillee
Contributor

Chillee commented May 3, 2022

Currently, torchdynamo.optimize(aot_autograd_speedup_strategy) doesn't actually turn on NVFuser (we should fix this - cc: @anijain2305). We need to add a with torch.jit.fuser("fuser2") context alongside it.

Turning that on results in significant speedups over eager (on my machine, A100 40 GB)

[Screenshot of benchmark results omitted]

As for the overhead of the first compilation, that's not something we've substantially investigated or measured in the past - we'll look into that.

Memory overhead could also come from a couple of different places; that's also something we need to look into.

Updated script with NVFuser
import contextlib
import time


import torch
import torchvision
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

N_WARMUP = 10
N_BENCH = 10

@torchdynamo.skip
def bench(batch_size, use_dynamo):
    model = torchvision.models.resnet50().cuda()
    x = torch.randn(batch_size, 3, 224, 224, dtype=torch.float, device='cuda')

    train_context = torchdynamo.optimize(aot_autograd_speedup_strategy) if use_dynamo is True else contextlib.nullcontext()

    torch.cuda.synchronize()
    t0  = time.time()

    with train_context:
        for i in range(N_WARMUP):
            out = model(x)
            out.sum().backward()

        torch.cuda.synchronize()
        t1 = time.time()

        for _ in range(N_BENCH):
            out = model(x)
            out.sum().backward()

        torch.cuda.synchronize()
        t2 = time.time()

    print('Training img/s (larger better):', batch_size / ((t2 - t1) / N_BENCH))
    print('Total time incl. overhead (smaller better):', t2 - t0)
    print()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    parser.add_argument('--use_dynamo', action='store_true', default=False)
    parser.add_argument('--use_nvfuser', action='store_true', default=False)
    args = parser.parse_args()
    print(args)
    if args.use_nvfuser:
        fuser = 'fuser2'
    else:
        fuser = 'fuser1'
    with torch.jit.fuser(fuser):
        bench(args.batch_size, args.use_dynamo)

@xwang233
Collaborator Author

xwang233 commented May 3, 2022

Thanks for the update @Chillee! Just to make sure: torchdynamo.optimize(aot_autograd_speedup_strategy) is supposed to be used together with NVFuser, not standalone on an eager-mode model. Is that correct?

If so, do users need to add an NVFuser context with torch.jit.fuser("fuser2") in their scripts, or is it going to be turned on automatically by torchdynamo.optimize(aot_autograd_speedup_strategy) in the future?

@Chillee
Contributor

Chillee commented May 3, 2022

Just to make sure: torchdynamo.optimize(aot_autograd_speedup_strategy) is supposed to be used together with NVFuser, not standalone on an eager-mode model. Is that correct?

Yes, it uses a rematerialization algorithm that needs to know what things are fusible (and is thus currently tuned against NVFuser).

If so, do users need to add an NVFuser context with torch.jit.fuser("fuser2") in their scripts, or is it going to be turned on automatically by torchdynamo.optimize(aot_autograd_speedup_strategy) in the future?

I think it totally makes sense to turn it on automatically; @anijain2305 is looking into that.
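In the meantime, a rough sketch (not an official TorchDynamo API; the helper name here is made up) of a context manager that turns both on together could look like this:

import contextlib

import torch
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

@contextlib.contextmanager
def dynamo_aot_with_nvfuser():
    # torch.jit.fuser("fuser2") selects NVFuser; "fuser1" would select NNC.
    with torch.jit.fuser("fuser2"), torchdynamo.optimize(aot_autograd_speedup_strategy):
        yield

Training code would then only need a single with dynamo_aot_with_nvfuser(): around the training loop.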

@csarofeen
Contributor

FYI: it still uses significantly more memory than eager mode on networks with nvFuser. We're seeing a situation where 3/4 of eager mode's maximum batch size still doesn't fit, but 1/2 does.

@Chillee
Contributor

Chillee commented May 5, 2022

@csarofeen Yep, neither the memory overhead nor the question of where the first-compilation overhead comes from has been resolved yet. There are some cases where we're careless with memory and keep references alive longer than they should be.

Working on those now.

@Chillee
Contributor

Chillee commented May 5, 2022

Some investigations. So, the experiments here are all run with a batch size of 256, and I'll primarily be reporting on peak memory usage (as measured by torch.cuda.max_memory_allocated() after the benchmarking has finished).

As a baseline, eager mode (no torchdynamo, etc.) reaches 22 GB of peak memory usage. On the other hand, torchdynamo + AOTAutograd + nvfuser runs at 39 GB of peak memory usage - not good!

  1. @anijain2305 discovered that one of the predominant sources of memory overhead here is coming from torchdynamo. There are some situations where we're doing analysis passes that substantially increase memory usage. If we simply remove torchdynamo from the picture (in this instance, by using AOTAutograd by itself without torchdynamo), we reduce memory usage down to 28 GB.

  2. There appear to be two more sources of increased memory usage compared to eager mode. The first is that somewhere in the TorchScript/profiling executor/nvfuser pipeline, we're using extra memory. Turning off the profiling executor (or using NNC instead) lowers peak memory to 22 GB. I'm not sure what it is; it needs more investigation.

  3. Another issue is with how AOTAutograd currently compiles FX graphs - namely, when calling a function in Python, all inputs to a function will be kept alive for the duration of the function. So, for example, if you have

def f(a, b):
   y = a + b
   return y.cos()

a is going to be kept alive for the duration of f (even once it's no longer needed). This is particularly problematic during the backward pass, where all of the activations are passed in as inputs. Our solution is to rewrite the function as

def f(args_lst):
    a, b = args_lst
    args_lst.clear()
    y = a + b
    return y.cos()

Which I've done here: pytorch/functorch#779

Unfortunately, this doesn't seem to reduce memory with TorchScript - I need to disable the lowering to TorchScript (and run it as an FX graph) in order to reduce memory. However, doing so reduces the peak memory usage to 22 GB, same as eager. (A standalone sketch of the input-lifetime effect is included below.)
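Here is a standalone illustration of the input-lifetime effect (mine, not code from the linked PR; sizes are arbitrary). Running each variant in a fresh process and comparing max_memory_allocated() shows the difference.

import torch

N = 8192  # each fp32 (N, N) tensor is ~256 MiB

def f(a, b):
    y = a + b        # `a` and `b` are dead after this line, but this frame (and
    return y.relu()  # the caller's argument list) still reference them until return

def f_list(args_lst):
    a, b = args_lst
    args_lst.clear()  # drop the caller-visible references
    y = a + b
    a = b = None      # drop this frame's references, so both inputs can be freed here
    return y.relu()

args = [torch.randn(N, N, device="cuda") for _ in range(2)]
out = f_list(args)    # peaks at ~3 tensors; plain f(*args) peaks at ~4
torch.cuda.synchronize()
print(torch.cuda.max_memory_allocated() / 2**30, 'GiB peak')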

Next steps:

  1. Investigate Torchdynamo's extra memory overhead (cc: @anijain2305).
  2. Figure out why TorchScript prevents the AOTAutograd modification from reducing memory usage.

cc: @csarofeen @xwang233

For now, if you want to benchmark more while reducing memory usage (while still using NVFuser), the easiest thing to try would be AOTAutograd without Dynamo.

Updated script with AOTAutograd
import contextlib
import time
from functorch.compile import memory_efficient_fusion


import torch
import torchvision
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

N_WARMUP = 10
N_BENCH = 10

@torchdynamo.skip
def bench(batch_size, use_dynamo):
    model = torchvision.models.resnet50().cuda()
    if use_dynamo:
        model = memory_efficient_fusion(model)
    x = torch.randn(batch_size, 3, 224, 224, dtype=torch.float, device='cuda')

    # just using aotautograd for now
    # train_context = torchdynamo.optimize(aot_autograd_speedup_strategy) if use_dynamo is True else contextlib.nullcontext()
    train_context = contextlib.nullcontext()

    torch.cuda.synchronize()
    t0  = time.time()

    with train_context:
        for i in range(N_WARMUP):
            out = model(x)
            out.sum().backward()

        torch.cuda.synchronize()
        t1 = time.time()

        for _ in range(N_BENCH):
            out = model(x)
            out.sum().backward()
            print(torch.cuda.max_memory_allocated()/1e9)
            torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        t2 = time.time()

    print('Training img/s (larger better):', batch_size / ((t2 - t1) / N_BENCH))
    print('Total time incl. overhead (smaller better):', t2 - t0)
    print()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    parser.add_argument('--use_dynamo', action='store_true', default=False)
    parser.add_argument('--use_nvfuser', action='store_true', default=False)
    args = parser.parse_args()
    print(args)
    if args.use_nvfuser:
        fuser = 'fuser2'
    else:
        fuser = 'fuser1'
    with torch.jit.fuser(fuser):
        bench(args.batch_size, args.use_dynamo)
        print(torch.cuda.max_memory_allocated()/1e9)

@jansel
Contributor

jansel commented May 5, 2022

One obvious source of memory overhead from TorchDynamo is config.dynamic_propagation=True. With this mode, TorchDynamo will create an example_value copy of every tensor and nn.Module, in order to have accurate python type/dtype/device/shape information. This could easily double memory usage in the worst case.

This approach is nice, in that it is highly accurate and trivial to implement -- however, it is very wasteful in the memory department.

We should rewrite dynamic_propagation to use meta tensors (and fall back to real tensors for ops where meta tensors aren't implemented).
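A rough sketch (my illustration of the idea, not TorchDynamo's actual code) of what replacing the clone with a meta tensor could look like:

import torch

def example_value_via_clone(x: torch.Tensor) -> torch.Tensor:
    # Approach described above: a real copy, which also duplicates the storage.
    return x.clone().requires_grad_(x.requires_grad)

def example_value_via_meta(x: torch.Tensor) -> torch.Tensor:
    # A meta tensor carries shape/dtype/stride metadata but allocates no storage,
    # so it costs no GPU memory; the original device has to be tracked separately,
    # since the copy itself lives on the "meta" device.
    return torch.empty_like(x, device="meta").requires_grad_(x.requires_grad)

x = torch.randn(16, 3, 224, 224, device="cuda")
m = example_value_via_meta(x)
print(m.shape, m.dtype, m.device)  # torch.Size([16, 3, 224, 224]) torch.float32 meta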

It is very possible there are other sources of memory overhead as well; I think @anijain2305 is looking into one.

Most things should work if you disable dynamic_propagation. The exceptions are that it enables constant inlining of tensor properties (dtype/device/ndim/shape/contiguous/layout/etc.) and handling of ops that return lists/tuples/etc.
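For anyone who wants to try that trade-off in the meantime, the flag named above can be flipped before entering the optimize context (sketch only; the flag name comes from this comment and its exact location may differ by torchdynamo version):

from torchdynamo import config

# Trade accuracy of propagated tensor metadata for lower memory during tracing.
config.dynamic_propagation = False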

@anijain2305
Contributor

@jansel After my quick investigation, I saw two sources of memory increase:

  1. dynamic_propagation as you mentioned
  2. normalize_ir - Basically, functionalization in Dynamo. I did not dive deep enough to find the source of memory increase here.

You already covered the first one in detail. I was thinking about using the CPU device for storing the cloned tensors to free up GPU memory, but fake/meta tensors sound like a better long-term solution.
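For what it's worth, a tiny sketch (illustrative only, not actual Dynamo code) of that CPU-offload idea for the cloned example values:

import torch

def clone_example_value_to_cpu(x: torch.Tensor) -> torch.Tensor:
    # Keep the example copy on the host so it doesn't occupy GPU memory;
    # the trade-off is slower access, and the copy no longer records the
    # original device.
    return x.detach().to('cpu').requires_grad_(x.requires_grad)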

The second one is specific to AOT Autograd. Hopefully this is temporary, because we are trying to move to functionalization at the dispatcher level.

Apart from this, I have a couple of small examples where Dynamo is not releasing/deleting a tensor when it goes out of scope. At the moment, it is unclear whether they are real issues or just badly set-up tests.

@csarofeen
Contributor

csarofeen commented May 6, 2022

@Chillee I tried AOTAutograd without Dynamo, but it seems there's an error that's likely associated with AMP. I manually disabled the TorchScript AMP pass but am still hitting this error.

PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth TIMM_BENCHMARK_ENABLE_TORCHDYNAMO=1 python benchmark.py --bench train --model resnet50 --img-size 224 -b 128 --amp --fuser nvfuser
WARNING: Overriding precision to 'amp' since --amp flag set.
Benchmarking in amp precision. NCHW layout. torchscript disabled
Model resnet50 created, param count: 25557032
Running train benchmark on resnet50 for 40 steps w/ input size (3, 224, 224) and batch size 128.
WARNING: "The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.4", line 20, in forward
    _tensor_constant53 = self._tensor_constant53
    where = torch.ops.aten.where(le, _tensor_constant53, div);  le = _tensor_constant53 = div = None
    native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward(where, convolution_52, primals_158, primals_319, primals_320, getitem_159, getitem_160, True, 1e-05, [True, True, True]);  convolution_52 = primals_158 = primals_319 = primals_320 = getitem_159 = getitem_160 = None
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem_161 = native_batch_norm_backward[0]
    getitem_162 = native_batch_norm_backward[1]
RuntimeError: Expected grad_output to have type Half but got Float
" while running benchmark. Reducing batch size to 128 for retry.
Traceback (most recent call last):
  File "benchmark.py", line 646, in <module>
    main()
  File "benchmark.py", line 630, in main
    results = benchmark(args)
  File "benchmark.py", line 573, in benchmark
    run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
  File "benchmark.py", line 519, in _try_run
    results = bench.run()
  File "benchmark.py", line 392, in run
    _step()
  File "benchmark.py", line 371, in _step
    self.loss(output, target).backward()
  File "/opt/conda/lib/python3.8/site-packages/functorch/_src/monkey_patching.py", line 97, in _backward
    return _old_backward(*args, **kwargs)
  File "/opt/pytorch/pytorch/torch/_tensor.py", line 399, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/pytorch/pytorch/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/pytorch/pytorch/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.8/site-packages/torchdynamo/eval_frame.py", line 65, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/functorch/_src/aot_autograd.py", line 182, in backward
    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
  File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1129, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.4", line 20, in forward
    _tensor_constant53 = self._tensor_constant53
    where = torch.ops.aten.where(le, _tensor_constant53, div);  le = _tensor_constant53 = div = None
    native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward(where, convolution_52, primals_158, primals_319, primals_320, getitem_159, getitem_160, True, 1e-05, [True, True, True]);  convolution_52 = primals_158 = primals_319 = primals_320 = getitem_159 = getitem_160 = None
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem_161 = native_batch_norm_backward[0]
    getitem_162 = native_batch_norm_backward[1]
RuntimeError: Expected grad_output to have type Half but got Float

The good news is that in FP32 it doesn't OOM the way it does with Dynamo.

@csarofeen
Contributor

Also, with just AOT, I'm seeing the following on ResNet50 with a V100 in FP32:
Eager: 350 ms/step
TorchScript+nvFuser: 350 ms/step
AOT: 331 ms/step

@malfet malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@ngimel ngimel closed this as completed Feb 14, 2023