pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.jit.script codegen warning with cuda and vmap #98557

Open jcnossen opened 1 year ago

jcnossen commented 1 year ago

🐛 Describe the bug

I'm getting the following warning, which hints at suboptimal speed and doesn't look like it should be triggered here at all.

```
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (Triggered internally at C:\cb\pytorch_1000000000000\work\third_party\nvfuser\csrc\manager.cpp:340.)
  batched_outputs = func(*batched_inputs, **kwargs)
```
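
For reference, the fallback path the warning mentions can be disabled for debugging by setting the environment variable it names before torch is imported; a minimal sketch of doing that from Python:

```python
import os

# Disable the nvfuser codegen fallback path (as suggested by the warning text)
# so the underlying compilation failure surfaces instead of silently falling back.
# This must be set before torch is imported.
os.environ["PYTORCH_NVFUSER_DISABLE"] = "fallback"

import torch
```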

This can be reproduced with the following code:

```python
import torch
from torch import vmap
import torch.jit

@torch.jit.script
def test(params):
    x0 = params[0]
    y = torch.arange(0, 64, dtype=torch.float32, device=params.device)
    return torch.cos(x0) * y

params = torch.zeros((200, 4), dtype=torch.float32, device='cuda')
torch.vmap(test, chunk_size=100)(params)
```

It seems to occur only when erf is passed an array instead of a scalar. I tested this on both PyTorch 2.0 and the nightly build (the version is collected below), and on both Windows and Ubuntu (see the attached environments).

Versions

windows_env.txt linux_env.txt

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99

jcnossen commented 1 year ago

linux_env.txt This is the correct Linux environment I used; the previous one was exported from the conda base env, which did not have PyTorch installed.

zou3519 commented 1 year ago

@jcnossen do you have more context about why you're composing torch.jit.script with vmap? Our stance so far is that we will likely not add torch.jit.script support for vmap, but instead add torch.compile support for vmap.
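
For illustration, a minimal sketch of that suggested direction, reusing the function from the repro above (how well torch.compile handles vmap may depend on the PyTorch version and platform):

```python
import torch

def test(params):
    x0 = params[0]
    y = torch.arange(0, 64, dtype=torch.float32, device=params.device)
    return torch.cos(x0) * y

# Compile the vmapped function instead of scripting the inner one;
# this is the torch.compile-based alternative suggested above.
batched = torch.compile(torch.vmap(test, chunk_size=100))

params = torch.zeros((200, 4), dtype=torch.float32, device="cuda")
out = batched(params)  # shape (200, 64)
```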

jcnossen commented 1 year ago

Hi, thanks for commenting. It's very convenient to write math-heavy functions with torch.jit.script (which speeds them up very significantly) and still be able to use vmap in combination with jacfwd (vmap currently seems to be the only way to compute autodifferentiated Jacobians over large batches). Also, I am mainly using Windows for this use case, where torch.compile support is not finished yet. I was planning on using this approach, but if it's not supported I can still implement the derivative manually.
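
For context, a minimal sketch of the vmap + jacfwd composition I mean, using the toy function from the repro above (the names here are illustrative, not the actual code from my project):

```python
import torch
from torch.func import jacfwd, vmap

def model(params):
    x0 = params[0]
    y = torch.arange(0, 64, dtype=torch.float32, device=params.device)
    return torch.cos(x0) * y

params = torch.zeros((200, 4), dtype=torch.float32, device="cuda")

# jacfwd differentiates the 64-element output w.r.t. the 4 parameters of
# one sample; vmap maps that over the batch of 200 samples.
jac = vmap(jacfwd(model))(params)  # shape (200, 64, 4)
```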