jcnossen opened this issue 1 year ago (status: open)
linux_env.txt: This is the correct Linux environment I used. The previous one was exported from the conda base env, which doesn't have PyTorch installed.
@jcnossen, do you have more context on why you're composing torch.jit.script with vmap? Our stance so far is that we will likely not add torch.jit.script support for vmap, and will instead add torch.compile support for vmap.
Hi, thanks for commenting. torch.jit.script is very convenient for writing a lot of math-heavy functions (it speeds them up significantly), and I'd like to still be able to use vmap in combination with jacfwd (vmap currently seems to be the only way to compute autodifferentiated Jacobians over large batches). I'm also mainly using Windows for this use case, where torch.compile support isn't finished yet. I was planning on using this combination, but if it won't be supported I can still implement the derivative manually.
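For context, here is a minimal eager-mode sketch (no torch.jit.script) of the two approaches mentioned above: batched Jacobians via `vmap(jacfwd(...))` from `torch.func` (PyTorch 2.0+), checked against a hand-written derivative. The function `model` is a hypothetical stand-in for a math-heavy function that calls erf; it is not taken from the original report.

```python
import math

import torch
from torch.func import jacfwd, vmap


def model(params: torch.Tensor) -> torch.Tensor:
    # Hypothetical math-heavy function: elementwise erf of a scaled input.
    # Input shape (n,), output shape (n,).
    return torch.erf(params / math.sqrt(2.0))


def manual_jacobian(batch: torch.Tensor) -> torch.Tensor:
    # Hand-written derivative: d/dx erf(x / sqrt(2)) = sqrt(2 / pi) * exp(-x^2 / 2).
    # The Jacobian of an elementwise function is diagonal, so embed the
    # derivatives on the diagonal of each (n, n) matrix in the batch.
    return torch.diag_embed(math.sqrt(2.0 / math.pi) * torch.exp(-batch**2 / 2.0))


# jacfwd differentiates a single sample; vmap maps that over the batch dimension.
batched_jac = vmap(jacfwd(model))

batch = torch.randn(1000, 3, dtype=torch.float64)
jac_auto = batched_jac(batch)        # shape (1000, 3, 3)
jac_manual = manual_jacobian(batch)  # shape (1000, 3, 3)
print(torch.allclose(jac_auto, jac_manual))
```

The manual version avoids autodiff entirely, which is the fallback described above. The `vmap(jacfwd(...))` version generalizes to functions whose derivatives are tedious to write by hand.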
🐛 Describe the bug
I'm getting the following warning, which hints at suboptimal performance and doesn't look like it should be triggered at all.
It seems to occur only when erf is passed a tensor rather than a scalar. I tested this on both PyTorch 2.0 and the nightly build (version collected below), and on both Windows and Ubuntu (see the collected environments).
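When comparing variants (tensor vs. scalar argument to erf, eager vs. scripted), it can help to capture warnings programmatically instead of reading console output. The sketch below assumes nothing beyond the standard `warnings` module and `torch.func`. `run_and_collect_warnings` and `erf_of_tensor` are hypothetical helpers, not the original reproducer. It runs the eager path only. Wrapping `torch.jit.script(erf_of_tensor)` the same way is the combination this report is about, so its behavior is deliberately not asserted here.

```python
import warnings

import torch
from torch.func import jacfwd, vmap


def run_and_collect_warnings(fn, *args):
    """Call fn(*args) and return (result, messages of warnings raised during the call)."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" prevents the default once-per-location filter from hiding repeats.
        warnings.simplefilter("always")
        result = fn(*args)
    return result, [str(w.message) for w in caught]


def erf_of_tensor(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: erf applied to a tensor argument, reduced to a scalar.
    return torch.erf(x).sum()


batch = torch.randn(8, 4)
# jacfwd of a scalar-valued function of a (4,) input has shape (4,); vmap adds the batch dim.
jac, messages = run_and_collect_warnings(vmap(jacfwd(erf_of_tensor)), batch)
print(jac.shape, messages)
```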
Versions
windows_env.txt linux_env.txt
cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99