
TorchDynamo fails to trace the graph when custom op is being used #87491

Open parthmannan opened 1 year ago

parthmannan commented 1 year ago

🐛 Describe the bug

TorchDynamo fails to trace through APEX's custom FastLayerNorm op: the graph breaks at the op, so none of the surrounding elementwise operations get fused. The issue can be reproduced with the code below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
import torchdynamo

tensor_dtype = torch.float16

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(768, 768, bias=False)
        self.fc2 = nn.Linear(768, 2048, bias=False)
        self.fast_layer_norm = FastLayerNorm(768)  # custom APEX CUDA op
        self.bias = Parameter(torch.rand(768))

    def forward(self, x, residuals):
        out = self.fc1(x)
        out = out + self.bias
        out = F.dropout(out, p=0.1, training=True)
        ln_input = out + residuals
        ln_out = self.fast_layer_norm(ln_input)  # Dynamo breaks the graph here
        out1 = self.fc2(ln_out)
        return out1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# FastLayerNorm is CUDA-only; run the model in half precision.
net = Net().to(device).half()

input_shapes = [(512, 64, 768),
                (512, 64, 768)]

def generate_io_tensor(net, input_shapes):
    # Build random half-precision inputs (with grads) and a gradient
    # tensor matching the network output for the backward call.
    input_tensors = []
    for shape in input_shapes:
        tensor = torch.rand(shape, dtype=tensor_dtype, requires_grad=True, device=device)
        input_tensors.append(tensor)

    target_tensor_size = net(*input_tensors).size()
    target_tensor = torch.rand(target_tensor_size, dtype=tensor_dtype, device=device)

    return input_tensors, target_tensor

network_fn = torchdynamo.optimize("aot_nvfuser")(net)

bench_iters = 10
for idx in range(bench_iters):
    # Fresh inputs every iteration; clear all gradients before the step.
    input_tensors, target_tensor = generate_io_tensor(net, input_shapes)
    for tensor in input_tensors:
        tensor.grad = None
    network_fn.zero_grad(set_to_none=True)

    outputs = network_fn(*input_tensors)
    outputs.backward(target_tensor)
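
As a quick check of where the break happens, torchdynamo's explain utility can be used. This is a hedged sketch: explain's exact return values have varied across torchdynamo versions, so only the leading summary string is used here.

# Hedged sketch: summarize graphs and graph breaks for one forward pass.
# Assumes this torchdynamo build exposes explain(); its return tuple has
# changed across versions, so only the first element (a summary string
# with graph/op counts and break reasons) is printed.
input_tensors, _ = generate_io_tensor(net, input_shapes)
explanation = torchdynamo.explain(net, *input_tensors)[0]
print(explanation)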

Running with AOT_FX_GRAPHS=1 set, I see the following output:

====== Forward (only) graph ======
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f16[512, 64, 768], arg1_1: f16[768], arg2_1: f16[768]):
        # Module stack: {}, File: /opt/conda/lib/python3.8/site-packages/apex/contrib/layer_norm/layer_norm.py:15, code: xmat = x.view((-1, hidden_size))
        view: f16[32768, 768] = torch.ops.aten.view.default(arg0_1, [-1, 768])
        return (view, arg1_1, arg2_1, arg0_1)

The expected output would contain the full forward graph so that nvFuser could fuse the remaining elementwise ops (bias add, dropout, residual add). Instead, the traced graph stops at the first view op inside FastLayerNorm, so nothing gets fused.
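For context, one mechanism torchdynamo exposes for keeping a custom op inside the traced graph is allow_in_graph, which records the call as an opaque node instead of breaking on it. A hedged sketch follows: FastLayerNormFN is assumed to be the autograd.Function backing FastLayerNorm in apex.contrib, and whether aot_nvfuser can then partition around the opaque node is exactly what this issue surfaces.

# Hedged sketch: ask Dynamo to keep the custom autograd.Function in the
# graph rather than breaking on it. FastLayerNormFN is assumed to be the
# Function that FastLayerNorm.forward dispatches to; support for opaque
# autograd.Functions under aot_autograd was incomplete at the time.
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

torchdynamo.allow_in_graph(FastLayerNormFN.apply)
network_fn = torchdynamo.optimize("aot_nvfuser")(net)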

Versions

Collecting environment information...
PyTorch version: 1.14.0a0+gite85dbcc
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 515.65.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] bert-pytorch==0.0.1a4
[pip3] clip-anytorch==2.5.0
[pip3] CoCa-pytorch==0.0.6
[pip3] dalle2-pytorch==1.10.5
[pip3] ema-pytorch==0.0.10
[pip3] functorch==1.14.0a0+408bcf1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] pytorch-transformers==1.2.0
[pip3] pytorch-warmup==0.1.0
[pip3] rotary-embedding-torch==0.1.5
[pip3] torch==1.14.0a0+gite85dbcc
[pip3] torch-fidelity==0.3.0
[pip3] torch-struct==0.5
[pip3] torchdynamo==1.14.0.dev0
[pip3] torchmetrics==0.10.0
[pip3] torchrec-nightly==2022.10.14
[pip3] torchtext==0.14.0a0+4570a56
[pip3] torchvision==0.15.0a0+f467349
[pip3] torchx-nightly==2022.10.17
[pip3] vector-quantize-pytorch==0.9.2
[conda] bert-pytorch              0.0.1a4                   dev_0    <develop>
[conda] clip-anytorch             2.5.0                    pypi_0    pypi
[conda] coca-pytorch              0.0.6                    pypi_0    pypi
[conda] dalle2-pytorch            1.10.5                   pypi_0    pypi
[conda] ema-pytorch               0.0.10                   pypi_0    pypi
[conda] functorch                 1.14.0a0+408bcf1          pypi_0    pypi
[conda] mkl                       2019.1                      144  
[conda] mkl-include               2019.1                      144  
[conda] nomkl                     3.0                           0  
[conda] numpy                     1.21.2                   pypi_0    pypi
[conda] numpy-base                1.21.5           py38hb8be1f0_2  
[conda] pytorch-transformers      1.2.0                    pypi_0    pypi
[conda] pytorch-warmup            0.1.0                    pypi_0    pypi
[conda] rotary-embedding-torch    0.1.5                    pypi_0    pypi
[conda] torch                     1.14.0a0+gite85dbcc           dev_0    <develop>
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torch-struct              0.5                      pypi_0    pypi
[conda] torchdynamo               1.14.0.dev0               dev_0    <develop>
[conda] torchmetrics              0.10.0                   pypi_0    pypi
[conda] torchrec-nightly          2022.10.14               pypi_0    pypi
[conda] torchtext                 0.14.0a0+4570a56           dev_0    <develop>
[conda] torchvision               0.15.0a0+f467349           dev_0    <develop>
[conda] torchx-nightly            2022.10.17               pypi_0    pypi
[conda] vector-quantize-pytorch   0.9.2                    pypi_0    pypi

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @mlazos @soumith @yanboliang @chunyuan-w @Xia-Weiwen @desertfire

penguinwu commented 9 months ago

@zou3519 Speculatively assigning this to you to investigate whether the issue is still valid. If it isn't, please close.

(Feel free to unassign if the problem still exists and you don't have the bandwidth to fix it soon.)