pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

No batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet cause a significant slowdown for per-sample gradients with torch.linalg.slogdet #984

Open AlphaBetaGamma96 opened 2 years ago

AlphaBetaGamma96 commented 2 years ago

TL;DR - torch.linalg.slogdet is over one order of magnitude slower in computing per-sample gradients in the latest nightly version of PyTorch/FuncTorch (1.13.0.dev20220721 / 0.3.0a0+e8a68f4) than a previous version of PyTorch/FuncTorch (1.12.0a0+git7c2103a / 0.2.0a0+9d6ee76) compiled from source. This seems to be due to the lack of batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet.

Thanks! :)

Hi All,

I've recently noticed that my code slowed down significantly (by around an order of magnitude) when moving from PyTorch 1.12 to 1.13. I've made a minimal reproducible example (MRE) to highlight this issue. For reference, this issue grew out of #979, which has some more info; that issue has since been solved, and this new issue was opened as per @vfdev-5's suggestion.

The MRE below computes per-sample gradients, with respect to the model parameters, of the Laplacian of a model w.r.t. its inputs. The script computes the per-sample gradients for N inputs from 1 to 6 and prints the walltime, and then uses torch.profiler.profile to give a clearer benchmark for N=4.
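For context, reading `kinetic_functorch` in the script below, the per-sample scalar appears to be the kinetic part of a local energy for ψ(x) = det M_θ(x), where M_θ(x) is the N×N matrix produced by the network and the model returns log|ψ|. This interpretation, and the symbol E_kin, are assumptions; the thread itself only calls it a Laplacian and a local energy. The code relies on the identity below, which follows from ∂²ᵢ log ψ = ∂²ᵢψ/ψ − (∂ᵢψ/ψ)²:

```latex
E_{\mathrm{kin}}(x) = -\frac{1}{2}\,\frac{\nabla_x^2 \psi(x)}{\psi(x)}
  = -\frac{1}{2}\sum_{i=1}^{N}\left[
      \frac{\partial^2 \log\lvert\psi(x)\rvert}{\partial x_i^2}
      + \left(\frac{\partial \log\lvert\psi(x)\rvert}{\partial x_i}\right)^{2}
    \right],
\qquad \psi(x) = \det M_\theta(x)

g_b = \nabla_\theta\, E_{\mathrm{kin}}(x_b), \qquad b = 1, \dots, B
```

The Hessian diagonal and the squared Jacobian in `kinetic_functorch` are the two bracketed terms, and `vmap(grad(...))` produces one g_b per sample.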

I've benchmarked two versions of PyTorch/FuncTorch. The first version was built from source (and can be found here). The only change is the slogdet_backward formula, which you can find here. The full version info for this "old-source" build is:

PyTorch version:    1.12.0a0+git7c2103a
CUDA version:       11.6
FuncTorch version:  0.2.0a0+9d6ee76

The other version is the latest nightly (hereafter referred to as "nightly"). Its full version info is:

PyTorch version:    1.13.0.dev20220721
CUDA version:       11.6
FuncTorch version:  0.3.0a0+e8a68f4

A comparison of walltime (in seconds) as N increases from 1 to 6 is as follows:

N | old-source (s) | nightly (s)
1 |     0.5719     |   2.4907    # first call; likely slower due to one-off warm-up (CUDA/cuBLAS initialization)
1 |     0.0133     |   2.0593
2 |     0.0870     |   2.4496
3 |     0.1153     |   2.9293
4 |     0.1129     |   3.3715
5 |     0.1576     |   3.8302
6 |     0.2059     |   4.2622
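The TL;DR's "over one order of magnitude" can be checked directly against this table. A small sketch that computes the per-N slowdown of "nightly" relative to "old-source", skipping the first (warm-up) row; the numbers are copied from the table, nothing else is assumed:

```python
# Walltimes in seconds from the table above: (N, old-source, nightly).
rows = [
    (1, 0.5719, 2.4907),  # first call, includes one-off warm-up
    (1, 0.0133, 2.0593),
    (2, 0.0870, 2.4496),
    (3, 0.1153, 2.9293),
    (4, 0.1129, 3.3715),
    (5, 0.1576, 3.8302),
    (6, 0.2059, 4.2622),
]

# Slowdown factor of nightly relative to old-source, excluding the warm-up row.
slowdowns = {n: nightly / old for n, old, nightly in rows[1:]}

for n, s in slowdowns.items():
    print(f"N={n}: nightly is {s:.1f}x slower than old-source")
```

The steady-state ratios range from about 20x (N=6) to over 150x (N=1), so every row after warm-up exceeds one order of magnitude.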

The torch.profiler output for N = 4 with the "old-source" version is shown below, sorted by cuda_time_total:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.38%     417.000us         4.25%       4.694ms      82.351us       0.000us         0.00%     127.220ms       2.232ms            57  
                                               aten::mm         0.69%     765.000us         3.77%       4.169ms      64.138us      25.971ms        24.41%     117.336ms       1.805ms            65  
                                              aten::bmm         0.45%     493.000us         1.06%       1.168ms      43.259us      64.058ms        60.20%      87.447ms       3.239ms            27  
       autograd::engine::evaluate_function: MmBackward0         0.12%     131.000us         1.37%       1.513ms     168.111us       0.000us         0.00%      42.798ms       4.755ms             9  
                                            MmBackward0         0.04%      40.000us         1.22%       1.347ms     149.667us       0.000us         0.00%      42.630ms       4.737ms             9  
                                   volta_dgemm_64x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      41.116ms        38.64%      41.116ms       4.112ms            10  
    autograd::engine::evaluate_function: AddmmBackward0         0.09%     103.000us         1.71%       1.890ms     189.000us       0.000us         0.00%      21.334ms       2.133ms            10  
                                  volta_dgemm_128x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      20.883ms        19.62%      20.883ms       3.481ms             6  
                                         AddmmBackward0         0.04%      39.000us         1.13%       1.254ms     125.400us       0.000us         0.00%      19.551ms       1.955ms            10  
                                   volta_dgemm_64x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us      14.590ms        13.71%      14.590ms       2.432ms             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 110.510ms
Self CUDA time total: 106.412ms

However, with the latest "nightly" version, the MRE slows down significantly, and the torch.profiler output is dominated by aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 aten::_linalg_solve_ex        14.79%        1.125s       141.49%       10.763s     212.108us       0.000us         0.00%        2.966s      58.452us         50741  
                                     aten::linalg_solve         0.09%       7.174ms       117.89%        8.967s     560.456ms       0.000us         0.00%        2.284s     142.775ms            16  
                                  aten::linalg_solve_ex         0.00%      37.000us        75.04%        5.708s     475.648ms       0.000us         0.00%        1.513s     126.102ms            12  
autograd::engine::evaluate_function: LinalgSolveExBa...         0.00%     122.000us        62.86%        4.781s     683.040ms       0.000us         0.00%        1.275s     182.126ms             7  
                                 LinalgSolveExBackward0         0.00%      76.000us        62.85%        4.781s     683.008ms       0.000us         0.00%        1.275s     182.124ms             7  
                                  aten::linalg_lu_solve         9.34%     710.739ms        35.82%        2.724s      55.427us     643.114ms        33.69%     831.927ms      16.926us         49152  
                              aten::linalg_lu_factor_ex         7.28%     553.883ms        20.95%        1.593s      27.784us     661.250ms        34.64%     721.665ms      12.585us         57344  
                                  aten::_linalg_slogdet         4.59%     349.122ms        72.77%        5.535s     658.539us       0.000us         0.00%     677.273ms      80.580us          8405  
void getf2_cta_32x32<double, double>(int, int, int, ...         0.00%       0.000us         0.00%       0.000us       0.000us     540.579ms        28.32%     540.579ms       9.427us         57344  
void trsm_batch_left_lower_kernel<double>(cublasTrsm...         0.00%       0.000us         0.00%       0.000us       0.000us     277.594ms        14.54%     277.594ms       5.648us         49152  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.607s
Self CUDA time total: 1.909s

functorch also emits a UserWarning that batching rules do not exist for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet, and that it falls back to a for-loop, which hurts performance:

~/pytorch_nightly/debug/per-sample-elocal.py:49: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
  sgn, logabs = torch.linalg.slogdet(mat)

~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_solve_ex. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
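To make "falls back to a for-loop" concrete, here is a conceptual sketch in plain PyTorch, not functorch's actual BatchedFallback.cpp code: without a batching rule the op is effectively called once per batch element and the results are stacked, whereas a batching rule lets a single call handle the whole batch. The function names are illustrative.

```python
import torch

torch.manual_seed(0)
mats = torch.randn(64, 4, 4, dtype=torch.float64)

def slogdet_loop(batch):
    # Roughly what a fallback amounts to: one op call per batch element, then stack.
    results = [torch.linalg.slogdet(m) for m in batch]
    sign = torch.stack([r.sign for r in results])
    logabsdet = torch.stack([r.logabsdet for r in results])
    return sign, logabsdet

def slogdet_batched(batch):
    # What a batching rule enables: a single call over the whole batch.
    out = torch.linalg.slogdet(batch)
    return out.sign, out.logabsdet
```

Both return the same values; the difference is the number of op calls. On a GPU, each loop iteration also launches its own kernels, which is consistent with the tens of thousands of aten::_linalg_solve_ex calls in the nightly profile above. Timing the two functions with `timeit` on a large batch makes the per-call overhead visible.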

The full script to reproduce this slowdown can be found below.

import torch
import torch.nn as nn
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

import functorch
from functorch import jacrev, jacfwd, hessian, make_functional, vmap, grad

import time 

_ = torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

#version info
print("PyTorch version:   ", torch.__version__)
print("CUDA version:      ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)

#wall-clock time, read after torch.cuda.synchronize() so that queued GPU work is counted
def sync_time() -> float:
  torch.cuda.synchronize()
  return time.perf_counter()

class model(nn.Module):

  def __init__(self, num_inputs, num_hidden):
    super(model, self).__init__()

    self.num_inputs=num_inputs
    self.func = nn.Tanh()

    self.fc1 = nn.Linear(2, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_inputs)

  def forward(self, x):
    """
    Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
    """

    idx=len(x.shape)             #creates args for repeat if vmap is used or not
    rep=[1 for _ in range(idx)]
    rep[-2] = self.num_inputs
    g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
    f = torch.cat((x,g), dim=-1)

    h = self.func(self.fc1(f))

    mat = self.fc2(h)
    sgn, logabs = torch.linalg.slogdet(mat)
    return sgn, logabs

#=================================================================================================#
#Profile code for N=1 to 6
#=================================================================================================#

B=4096 #batch
N=2    #input nodes
H=128  #number of hidden nodes
device=torch.device("cuda")

for N in [1,1,2,3,4,5,6]: #N=1 runs twice; the first run includes one-off warm-up cost

  net = model(N, H)
  net = net.to(device)

  x = torch.randn(B,N,1,device=device) #input data
  fnet, params = make_functional(net)

  def logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

  def kinetic_functorch(params, X):
    #do once, and re-use via has_aux?
    calc_jacobian = jacrev(logabs, argnums=1) 
    #can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
    calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1) 

    return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)

  #per-sample gradients for local energy w.r.t params via FuncTorch
  t1=sync_time()
  elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
  t2=sync_time()

  print("N: %2i | Walltime: %6.4f (s)" % (N, t2-t1))

#=================================================================================================#
#Profile code for N=4
#=================================================================================================#

N=4
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
  net = model(N, H)
  net = net.to(device)

  x = torch.randn(B,N,1,device=device) #input data
  fnet, params = make_functional(net)

  def logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

  def kinetic_functorch(params, X):
    #do once, and re-use via has_aux?
    calc_jacobian = jacrev(logabs, argnums=1) 
    #can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
    calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1) 

    return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)

  #per-sample gradients for local energy w.r.t params via FuncTorch
  t1=sync_time()
  elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
  t2=sync_time()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
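Since PyTorch 2.0, functorch's transforms are available as `torch.func`, and `functorch.make_functional` is deprecated in favour of `torch.func.functional_call`. A minimal sketch of the same per-sample gradient computation with that API, assuming PyTorch 2.0 or later; the model is a condensed copy of the one above (CPU, smaller sizes, `expand_as` instead of `repeat`), and the names `Model` and `kinetic` are illustrative:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

class Model(nn.Module):
    def __init__(self, num_inputs, num_hidden):
        super().__init__()
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)

    def forward(self, x):
        # x: [..., N, 1]; append the mean over the N inputs as a second feature
        g = x.mean(dim=-2, keepdim=True).expand_as(x)
        f = torch.cat((x, g), dim=-1)
        mat = self.fc2(torch.tanh(self.fc1(f)))  # [..., N, N]
        return torch.linalg.slogdet(mat)

N, H, B = 4, 32, 8
net = Model(N, H)
params = {k: v.detach() for k, v in net.named_parameters()}

def logabs(params, x):
    # functional_call runs net with the given parameter dict; [1] is logabsdet
    return functional_call(net, params, (x,))[1]

def kinetic(params, x):
    # x: a single sample of shape [N, 1]
    jac = jacrev(logabs, argnums=1)(params, x).squeeze(-1)            # [N]
    hess = jacrev(jacrev(logabs, argnums=1), argnums=1)(params, x)    # [N, 1, N, 1]
    lap = hess.squeeze(-1).squeeze(-2).diagonal()                     # [N]
    return -0.5 * (lap + jac.pow(2)).sum()

x = torch.randn(B, N, 1)
# Per-sample gradients w.r.t. params: one entry per sample along dim 0
per_sample_grads = vmap(grad(kinetic), in_dims=(None, 0))(params, x)
```

The result is a dict keyed like `net.named_parameters()`, with a leading batch dimension of size B on every tensor.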

Thanks in advance! :)

zou3519 commented 2 years ago

cc @samdow -- were these the ones you were planning on adding? I haven't checked if these fall into the low-hanging-fruit category

samdow commented 2 years ago

Yep! I have changes for all of these functions locally, so let me cherry-pick them onto the new repo and put up some PRs.

zou3519 commented 2 years ago

@samdow did https://github.com/pytorch/pytorch/pull/82177 cover all of these?

samdow commented 2 years ago

Not quite! We discussed this a bit offline, but @AlphaBetaGamma96, I'm hoping to get https://github.com/pytorch/pytorch/pull/82814 in soon, and then I'll double-check that we see speedups for your example (thanks for the repro!). Sorry for the delay; we ran into some AD-related bugs when adding these rules.

AlphaBetaGamma96 commented 2 years ago

No need to apologize for the delay @samdow, thanks for solving the batch rule! Fingers crossed it all works!

samdow commented 2 years ago

Hi @AlphaBetaGamma96! Just wanted to let you know that the linalg_solve batching rule just landed in https://github.com/pytorch/pytorch/pull/82814. I tested locally that this example ran significantly faster after the fix than before it (the exact numbers were around the same order of magnitude as yours).

Thanks for the issue and thanks for your patience as we worked through some AD issues

AlphaBetaGamma96 commented 2 years ago

Hi @samdow, thanks for fixing this issue! A bit of a silly question, but I remember reading somewhere that functorch is being merged directly into pytorch (if that's the correct phrase). So, do I just download the latest pytorch nightly (and now ignore functorch), or do I update both functorch and pytorch to their latest versions?

samdow commented 2 years ago

Not a silly question; there's been a lot of change in the past couple of weeks. Yes, the main development of functorch is now being done in pytorch/pytorch. So if you're building from source, you'll want to build pytorch master, cd into the functorch directory, and then build functorch.

Alternatively (this workflow may occasionally break for about a day, and I don't use it myself, so let me know if it doesn't work for you), you can download the pytorch nightly and then build the newest version of functorch against it. To get functorch this way, either (1) download the pytorch source, cd into the functorch directory, and build that, or (2) download the functorch repo and build that (we aren't developing there anymore, but it has a read-only sync for the moment).

Currently, sadly, downloading the pytorch nightly doesn't automatically give you the newest functorch. People are actively working on making this work, and we can keep you updated.

Let me know if any of that doesn't make sense!

AlphaBetaGamma96 commented 2 years ago

Hi @samdow, so just to check, I need to download the latest pytorch nightly from https://pytorch.org/get-started/locally/ and then install functorch from source (from https://github.com/pytorch/functorch#installing-functorch-main), and that should be ok? (Assuming I've correctly understood what you've said)

EDIT: That seems to have worked

AlphaBetaGamma96 commented 2 years ago

For completeness, I thought I'd share the results for the latest nightly version

PyTorch version:    1.13.0.dev20220820
CUDA version:       11.6
FuncTorch version:  0.3.0a0+86a9049
N:  1 | Walltime: 0.4445 (s)
N:  1 | Walltime: 0.0107 (s)
N:  2 | Walltime: 0.0928 (s)
N:  3 | Walltime: 0.1296 (s)
N:  4 | Walltime: 0.1261 (s)
N:  5 | Walltime: 0.1634 (s)
N:  6 | Walltime: 0.2002 (s)
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2022-08-20 18:11:05 27507:27507 ActivityProfilerController.cpp:300] Completed Stage: Collection
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::bmm         0.52%     620.000us         2.01%       2.397ms      40.627us      88.138ms        73.46%     172.966ms       2.932ms            59  
                                           aten::matmul         0.55%     662.000us         2.93%       3.503ms      61.456us       0.000us         0.00%     158.351ms       2.778ms            57  
                                               aten::mm         0.39%     468.000us         1.95%       2.336ms      48.667us      12.092ms        10.08%      87.588ms       1.825ms            48  
    autograd::engine::evaluate_function: AddmmBackward0         0.07%      86.000us         1.27%       1.517ms     151.700us       0.000us         0.00%      38.730ms       3.873ms            10  
                                   volta_dgemm_64x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      37.369ms        31.14%      37.369ms       3.737ms            10  
                                         AddmmBackward0        -0.19%    -225.000us         0.90%       1.071ms     107.100us       0.000us         0.00%      36.673ms       3.667ms            10  
                                  volta_dgemm_128x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      33.590ms        27.99%      33.590ms       5.598ms             6  
      autograd::engine::evaluate_function: BmmBackward0         0.03%      30.000us         0.66%     791.000us     131.833us       0.000us         0.00%      32.940ms       5.490ms             6  
                                           BmmBackward0        -0.10%    -122.000us         0.61%     734.000us     122.333us       0.000us         0.00%      32.758ms       5.460ms             6  
autograd::engine::evaluate_function: LinalgSolveExBa...         0.02%      24.000us        63.43%      75.798ms      18.950ms       0.000us         0.00%      15.985ms       3.996ms             4  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 119.494ms
Self CUDA time total: 119.989ms
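Putting these numbers next to the earlier table shows how much of the regression the fix recovered. A small sketch using only the steady-state walltimes reported in this thread (the first, warm-up N=1 call is excluded from each run); it assumes all three runs used the same hardware, which the thread does not state explicitly:

```python
# Steady-state walltimes in seconds, keyed by N.
old_source = {1: 0.0133, 2: 0.0870, 3: 0.1153, 4: 0.1129, 5: 0.1576, 6: 0.2059}
slow_nightly = {1: 2.0593, 2: 2.4496, 3: 2.9293, 4: 3.3715, 5: 3.8302, 6: 4.2622}
fixed_nightly = {1: 0.0107, 2: 0.0928, 3: 0.1296, 4: 0.1261, 5: 0.1634, 6: 0.2002}

# Speedup of the fixed nightly over the pre-fix nightly, and its ratio to old-source.
speedup = {n: slow_nightly[n] / fixed_nightly[n] for n in fixed_nightly}
vs_old = {n: fixed_nightly[n] / old_source[n] for n in fixed_nightly}

for n in sorted(fixed_nightly):
    print(f"N={n}: {speedup[n]:5.1f}x faster than before the fix, "
          f"{vs_old[n]:.2f}x the old-source walltime")
```

For N >= 2 this gives roughly a 21x to 27x speedup over the pre-fix nightly and puts the fixed nightly within about 12% of the old-source build.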

AlphaBetaGamma96 commented 2 years ago

Hi @samdow, apologies for re-opening this issue, but could batching rules also be added for the torch.linalg.lu* functions? It seems that torch.linalg.slogdet uses an LU decomposition (in order to compute the determinant in the log domain), and the LU functions don't seem to have batching rules. I've posted the warning below; it highlights aten::linalg_lu_solve as lacking a batching rule.

~/main.py:201: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
  sgns, logabss = torch.slogdet(matrices * torch.exp(log_envs))
~/anaconda3/envs/pytorch_nightly_env/lib/python3.10/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_lu_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-vire9c5a/functorch/csrc/BatchedFallback.cpp:83.)
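As background for the LU question above, here is a minimal sketch of how a sign and log-determinant can be obtained from an LU factorization with partial pivoting. It is an illustrative re-derivation, not PyTorch's internal slogdet implementation, although the aten::linalg_lu_factor_ex calls in the nightly profile are consistent with an LU-based approach. With A = P L U and L unit lower-triangular, det A = det P · ∏ U_ii, so log|det A| = Σ log|U_ii|. The name `slogdet_via_lu` is hypothetical.

```python
import torch

def slogdet_via_lu(A: torch.Tensor):
    """Sign and log|det| of a (batch of) square real matrices via LU with partial pivoting."""
    # lu_factor returns the packed L/U factors and 1-indexed LAPACK-style pivots.
    LU, pivots = torch.linalg.lu_factor(A)
    diag_u = LU.diagonal(dim1=-2, dim2=-1)
    # Every pivot that differs from its own (1-indexed) row is one row swap,
    # and det(P) = (-1) ** num_swaps.
    rows = torch.arange(1, A.shape[-1] + 1, dtype=pivots.dtype, device=A.device)
    num_swaps = (pivots != rows).sum(dim=-1)
    perm_sign = 1 - 2 * (num_swaps % 2)
    sign = perm_sign.to(A.dtype) * torch.sign(diag_u).prod(dim=-1)
    # Summing log|U_ii| stays in the log domain, avoiding overflow of the product.
    logabsdet = diag_u.abs().log().sum(dim=-1)
    return sign, logabsdet
```

Because it only uses batched tensor ops, the function handles a leading batch dimension directly, e.g. `slogdet_via_lu(torch.randn(5, 6, 6, dtype=torch.float64))`.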

zou3519 commented 2 years ago

I added linalg_lu_solve yesterday, could you reinstall the latest pytorch nightly and try again?

AlphaBetaGamma96 commented 1 year ago

Hi @zou3519, sorry for the late response. I've installed the latest pytorch nightly and the UserWarning isn't there anymore. Thank you!

EDIT: removed issue with functorch install. Fresh install works fine!

AlphaBetaGamma96 commented 1 year ago

Hi @zou3519, I've just noticed that if torch.slogdet (instead of torch.linalg.slogdet) is used, vmap falls back to a for-loop and stack, whereas torch.linalg.slogdet works as expected. I assume torch.slogdet is going to be removed in a future update (as it's moving to the linalg library), but I thought I'd mention it here in case this problem shows up with other functions.

UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /opt/conda/conda-bld/pytorch_1664781140419/work/aten/src/ATen/functorch/BatchedFallback.cpp:82.)
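When several ops fall back, as in the warnings quoted in this thread, it can help to collect their names in one place rather than scanning the console. A minimal sketch, assuming the warning is surfaced through Python's `warnings` module (as the `UserWarning:` prefix suggests) and keeps the message wording quoted here; `collect_vmap_fallbacks` is a hypothetical helper name:

```python
import re
import warnings
from contextlib import contextmanager

# Matches the op name in messages like:
# "... implemented the batching rule for aten::slogdet. Please file us an issue ..."
_FALLBACK_RE = re.compile(r"batching rule for (\S+?)\. ")

@contextmanager
def collect_vmap_fallbacks():
    """Record warnings raised inside the block; afterwards, the yielded set holds the fallback op names."""
    ops = set()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # do not let "once" filtering hide repeats
        yield ops
    for w in caught:
        match = _FALLBACK_RE.search(str(w.message))
        if match:
            ops.add(match.group(1))
```

Usage: wrap the vmap call, e.g. `with collect_vmap_fallbacks() as ops: vmap(f)(x)`, then inspect `ops` after the block exits.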

zou3519 commented 1 year ago

I think I know what is going on here, will fix soon. EDIT: fix over at https://github.com/pytorch/pytorch/pull/86815 . Thanks as always for reporting bugs, @AlphaBetaGamma96.

It doesn't look like we've actually deprecated torch.slogdet in favor of torch.linalg.slogdet, is that right, @lezcano? (https://pytorch.org/docs/1.13/generated/torch.slogdet.html?highlight=torch+slogdet#torch.slogdet). In that case, we do want vmap support for both operators, since users aren't being directed to use torch.linalg.slogdet over torch.slogdet.

lezcano commented 1 year ago

We haven't deprecated it. It's left there as an alias:

// Alias
std::tuple<Tensor, Tensor> slogdet(const Tensor& A) {
  return at::linalg_slogdet(A);
}

std::tuple<Tensor&, Tensor&> slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
  return at::linalg_slogdet_out(sign, logabsdet, A);
}
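The alias can be sanity-checked from Python, along with the fact that both spellings run under vmap. A short sketch, assuming PyTorch 2.0 or later (for `torch.func.vmap`). The vmapped functions convert the namedtuple result to a plain tuple so that vmap's output is a simple tuple of tensors.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
A = torch.randn(8, 3, 3, dtype=torch.float64)

# torch.slogdet forwards to at::linalg_slogdet, so the two should agree exactly.
old = torch.slogdet(A)
new = torch.linalg.slogdet(A)

# Under vmap, each call sees a single 3x3 matrix per batch element.
old_v = vmap(lambda m: tuple(torch.slogdet(m)))(A)
new_v = vmap(lambda m: tuple(torch.linalg.slogdet(m)))(A)
```

On versions that still lacked a batching rule for aten::slogdet, the `old_v` line is where the fallback warning quoted above would appear.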