AlphaBetaGamma96 opened 2 years ago
cc @samdow -- were these the ones you were planning on adding? I haven't checked if these fall into the low-hanging-fruit category
Yep! I have changes for all of these functions locally so let me pick them onto the new repo and put up some PRs
@samdow did https://github.com/pytorch/pytorch/pull/82177 cover all of these?
Not quite! Discussed a bit offline but @AlphaBetaGamma96 I'm hoping to get https://github.com/pytorch/pytorch/pull/82814 in soon and then I'll double check that we see speedups for your example (thanks for the repro!). Sorry for the delay--we ran into some AD related bugs because of adding these rules
No need to apologize for the delay @samdow, thanks for solving the batch rule! Fingers crossed it all works!
Hi @AlphaBetaGamma96! Just wanted to let you know that linalg_solve just landed in https://github.com/pytorch/pytorch/pull/82814. I tested locally that this example runs significantly faster after the fix than before it (exact numbers around the same order of magnitude as yours)
Thanks for the issue and thanks for your patience as we worked through some AD issues
Hi @samdow, thanks for fixing this issue! A bit of a silly question, but I remember reading somewhere that functorch is being merged directly into pytorch (if that's the correct phrase). So, would I have to just download the latest nightly of pytorch (and now just ignore functorch), or do I just update functorch to its latest version as well as pytorch?
Not a silly question, there's been a lot of change in the past couple of weeks. Yes the main development of functorch is being done in pytorch/pytorch. So if you're building from source, you'll want to build pytorch master, cd into the functorch directory, and then build functorch.
Or (this workflow may occasionally break for ~a day and I don't do this, so let me know if it doesn't work for you), you can download pytorch nightly and then build the newest version of functorch against that. Options for getting functorch this way are either (1) downloading pytorch, cd-ing into the functorch directory, and building that or (2) downloading the functorch repo and building that (we aren't developing here but have a read-only sync for the moment)
Currently, sadly, it doesn't just work to download pytorch nightly and get the newest functorch with it. We have people actively working on making this work and we can keep you updated
Let me know if any of that doesn't make sense!
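For reference, the nightly-plus-source workflow described above might look roughly like the following. This is a sketch, not official install docs: the index URL shown is the standard nightly wheel index for CUDA 11.6, but check pytorch.org/get-started for the right one for your setup.

```shell
# Install the PyTorch nightly wheel (pick the index URL matching your CUDA/CPU setup).
pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116

# Option 1: build functorch from the functorch directory inside pytorch/pytorch.
git clone https://github.com/pytorch/pytorch
cd pytorch/functorch
pip install -e .

# Option 2: build from the (read-only, synced) pytorch/functorch repo.
git clone https://github.com/pytorch/functorch
cd functorch
pip install -e .
```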
Hi @samdow, so just to check, I need to download the latest pytorch nightly from https://pytorch.org/get-started/locally/ and then install functorch from source (from https://github.com/pytorch/functorch#installing-functorch-main), and that should be ok? (Assuming I've correctly understood what you've said)
EDIT: That seems to have worked
For completeness, I thought I'd share the results for the latest nightly version
PyTorch version: 1.13.0.dev20220820
CUDA version: 11.6
FuncTorch version: 0.3.0a0+86a9049
N: 1 | Walltime: 0.4445 (s)
N: 1 | Walltime: 0.0107 (s)
N: 2 | Walltime: 0.0928 (s)
N: 3 | Walltime: 0.1296 (s)
N: 4 | Walltime: 0.1261 (s)
N: 5 | Walltime: 0.1634 (s)
N: 6 | Walltime: 0.2002 (s)
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:300] Completed Stage: Collection
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::bmm 0.52% 620.000us 2.01% 2.397ms 40.627us 88.138ms 73.46% 172.966ms 2.932ms 59
aten::matmul 0.55% 662.000us 2.93% 3.503ms 61.456us 0.000us 0.00% 158.351ms 2.778ms 57
aten::mm 0.39% 468.000us 1.95% 2.336ms 48.667us 12.092ms 10.08% 87.588ms 1.825ms 48
autograd::engine::evaluate_function: AddmmBackward0 0.07% 86.000us 1.27% 1.517ms 151.700us 0.000us 0.00% 38.730ms 3.873ms 10
volta_dgemm_64x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 37.369ms 31.14% 37.369ms 3.737ms 10
AddmmBackward0 -0.19% -225.000us 0.90% 1.071ms 107.100us 0.000us 0.00% 36.673ms 3.667ms 10
volta_dgemm_128x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 33.590ms 27.99% 33.590ms 5.598ms 6
autograd::engine::evaluate_function: BmmBackward0 0.03% 30.000us 0.66% 791.000us 131.833us 0.000us 0.00% 32.940ms 5.490ms 6
BmmBackward0 -0.10% -122.000us 0.61% 734.000us 122.333us 0.000us 0.00% 32.758ms 5.460ms 6
autograd::engine::evaluate_function: LinalgSolveExBa... 0.02% 24.000us 63.43% 75.798ms 18.950ms 0.000us 0.00% 15.985ms 3.996ms 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 119.494ms
Self CUDA time total: 119.989ms
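For anyone landing here, the per-sample-gradient pattern being timed above can be sketched like this. It's a toy stand-in rather than the actual MRE (the model and sizes here are made up), using `torch.func`, which is where the functorch API lives inside PyTorch in later releases:

```python
import torch
from torch.func import grad, vmap  # functorch's API, shipped inside PyTorch

def logabsdet(weight, x):
    # Toy per-sample loss built on slogdet; `weight` stands in for the model
    # parameters and `x` for a single input sample.
    sign, logabs = torch.linalg.slogdet(weight @ x)
    return logabs

weight = torch.randn(3, 3, dtype=torch.float64)
xs = torch.randn(8, 3, 3, dtype=torch.float64)  # a batch of 8 samples

# One gradient w.r.t. `weight` per sample, vectorized over the batch dim of xs.
per_sample = vmap(grad(logabsdet), in_dims=(None, 0))(weight, xs)
print(per_sample.shape)  # torch.Size([8, 3, 3])
```

Whether this runs vectorized or falls back to a for-loop is exactly what the batching rules discussed in this issue control.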
Hi @samdow, apologies for re-opening this issue, but could the torch.linalg.lu* functions also get batching rules? It seems that when torch.linalg.slogdet is called, it calls torch.linalg.lu for the decomposition (in order to compute the determinant in the log domain), which doesn't seem to have a batching rule. I've posted the warning below; it highlights aten::linalg_lu_solve as not having a batching rule.
~/main.py:201: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
sgns, logabss = torch.slogdet(matrices * torch.exp(log_envs))
~/anaconda3/envs/pytorch_nightly_env/lib/python3.10/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
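If it helps triage, a small helper like the following (my own sketch, not a functorch API) can check whether vmapping a function trips the missing-batching-rule fallback warning:

```python
import warnings

import torch
from torch.func import vmap

def hits_vmap_fallback(fn, batched_arg):
    """Return True if running `fn` under vmap emits the batching-rule fallback warning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        vmap(fn)(batched_arg)
    return any("batching rule" in str(w.message) for w in caught)

mats = torch.randn(4, 3, 3)
# On builds where linalg_slogdet has a batching rule this prints False.
print(hits_vmap_fallback(torch.linalg.slogdet, mats))
```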
I added linalg_lu_solve yesterday, could you reinstall the latest pytorch nightly and try again?
Hi @zou3519, sorry for the late response. I've installed the latest pytorch nightly and the UserWarning isn't there anymore. Thank you!
EDIT: removed issue with functorch install. Fresh install works fine!
Hi @zou3519, I've just noticed that if torch.slogdet (instead of torch.linalg.slogdet) is used, it defaults to a for-loop and stack, but torch.linalg.slogdet works as expected. I assume torch.slogdet is going to be removed in a future update (as it's moving to the linalg library), but I thought I'd mention it here in case this problem emerges in other situations with other functions.
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /opt/conda/conda-bld/pytorch_1664781140419/work/aten/src/ATen/functorch/BatchedFallback.cpp:82.)
I think I know what is going on here, will fix soon. EDIT: fix over at https://github.com/pytorch/pytorch/pull/86815 . Thanks as always for reporting bugs, @AlphaBetaGamma96.
It doesn't look like we've actually deprecated torch.slogdet in favor of torch.linalg.slogdet, is that right @lezcano ? (https://pytorch.org/docs/1.13/generated/torch.slogdet.html?highlight=torch+slogdet#torch.slogdet). In that case we do want vmap support for both operators since users aren't being directed to use torch.linalg.slogdet over torch.slogdet.
We haven't deprecated it. It's left there as an alias:
// Alias
std::tuple<Tensor, Tensor> slogdet(const Tensor& A) {
  return at::linalg_slogdet(A);
}

std::tuple<Tensor&, Tensor&> slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
  return at::linalg_slogdet_out(sign, logabsdet, A);
}
TL;DR - torch.linalg.slogdet is over one order of magnitude slower at computing per-sample gradients in the latest nightly version of PyTorch/FuncTorch (1.13.0.dev20220721 / 0.3.0a0+e8a68f4) than in a previous version of PyTorch/FuncTorch (1.12.0a0+git7c2103a / 0.2.0a0+9d6ee76) compiled from source. This seems to be due to the lack of batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet.

Thanks! :)
Hi All,
I've recently noticed that my code slowed down significantly (by around an order of magnitude) when moving from PyTorch 1.12 to 1.13. I've made a minimal reproducible example to highlight this issue. For reference, this issue started from #979, with some more info there; although that issue has been resolved, this new issue was opened per @vfdev-5's suggestion.

The MRE below computes per-sample gradients, with respect to the parameters, of the Laplacian of a model w.r.t. its inputs. The script computes the per-sample gradients for N inputs from 1 to 6 and shows the walltime; I then use torch.profiler.profile to give a clearer benchmark for N=4.

I've benchmarked two versions of PyTorch/FuncTorch. The first version was built from source (and can be found here). The only thing that changed is the slogdet_backward formula, which you can find here. The full version string for this "old-source" version is,

The other version is the latest nightly (hereafter referred to as "nightly"). The full version string of this "nightly" version is,
A comparison of walltime (measured in seconds) as N increases from 1 to 6 is as follows.

The torch.profiler.profile case of N=4 for the "old-source" version is shown below, sorted by cuda_time_total.

However, when using the latest "nightly" version, the MRE slows down significantly and the torch.profiler.profile output is dominated by the following ops: aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet. functorch also prompts me with a UserWarning that batching rules do not exist for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet, and it defaults to a for-loop, which will affect performance.

The full script to reproduce this error can be found below.
Thanks in advance! :)
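(The full script isn't reproduced in this excerpt. As a rough sketch of the shape of the benchmark, not the original MRE: per-sample gradients through slogdet, computed vectorized with vmap and again with a plain Python loop, timing both. Names and sizes here are made up.)

```python
import time

import torch
from torch.func import grad, vmap

def loss(params, x):
    # Mirrors the pattern in the MRE: slogdet of an elementwise-scaled matrix.
    sign, logabs = torch.linalg.slogdet(params * torch.exp(x))
    return logabs

d, n = 8, 6
params = torch.randn(d, d, dtype=torch.float64)
xs = torch.randn(n, d, d, dtype=torch.float64)

# Vectorized per-sample gradients via vmap.
t0 = time.perf_counter()
g_vmap = vmap(grad(loss), in_dims=(None, 0))(params, xs)
t_vmap = time.perf_counter() - t0

# Naive per-sample loop, for comparison.
t0 = time.perf_counter()
g_loop = torch.stack([grad(loss)(params, x) for x in xs])
t_loop = time.perf_counter() - t0

# Both paths must agree; the interesting quantity is the walltime gap,
# which is what the missing batching rules blow up.
assert torch.allclose(g_vmap, g_loop)
print(f"vmap: {t_vmap:.4f}s  loop: {t_loop:.4f}s")
```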