pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

25% Performance regression from v0.1.1 to 0.2.0 when calculating hessian #989

Open yueyericardo opened 1 year ago

yueyericardo commented 1 year ago

Hi developers,

After I upgraded functorch from v0.1.1 to 0.2.0, I noticed a 25% performance regression when calculating the Hessian. Please check the following benchmark results and the attached benchmark script.

Please let me know if I did anything wrong, and also whether the perf regression could be fixed. Thanks!

Benchmark result

Benchmark result on NVIDIA A100

# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 0.00e+00
max hessian    error: functorch: 0.00e+00
reference_hessian: 61.837 ms
functorch_hessian: 29.474 ms

# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 62.519 ms
functorch_hessian: 39.666 ms  (0.75 X)

Benchmark result on NVIDIA A6000

# torch 1.11 and functorch 0.1.1
===== benchmark without backward =====
max pred       error: functorch: 1.49e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 65.984 ms
functorch_hessian: 33.662 ms

# torch 1.12 and functorch 0.2.0
===== benchmark without backward =====
max pred       error: functorch: 1.86e-08
max hessian    error: functorch: 0.00e+00
reference_hessian: 67.285 ms
functorch_hessian: 49.723 ms (0.68 X)

benchmark script

benchmark.py

import time
import argparse
from functorch import vmap, jacrev, jacfwd
import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = False

_ = torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 10000
x = torch.randn(B, D1).to(device)
run_backward = False

model = nn.Sequential(
    nn.Linear(D1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, D2),
).to(device)

def predict(x):
    torch.cuda.nvtx.range_push("forward")
    out = model(x)
    torch.cuda.nvtx.range_pop()
    return out, out  # returning two outputs is needed for jacrev's has_aux (auxiliary output)

def reference_hessian():
    x_ = x.clone().requires_grad_()
    ones = torch.ones(B, device=x.device)
    pred, _ = predict(x_)
    jacobian_rows = [None] * D2
    hessian_rows = [None] * (D2 * D1)
    for i in range(D2):
        torch.cuda.nvtx.range_push("autograd jacobian")
        jacobian_rows[i] = torch.autograd.grad(
            pred[:, i], x_, ones, create_graph=True
        )[0]
        torch.cuda.nvtx.range_pop()

    for i in range(D2):
        for j in range(D1):
            torch.cuda.nvtx.range_push("autograd hessian")
            hessian_rows[i * D1 + j] = torch.autograd.grad(
                jacobian_rows[i][:, j], x_, ones, create_graph=True
            )[0]
            torch.cuda.nvtx.range_pop()

    jacobian = torch.stack(jacobian_rows)  # [D2, B, D1]
    hessian = torch.stack(hessian_rows)  # [D2 * D1, B, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian.transpose(0, 1), pred

def functorch_hessian():
    x_ = x.clone().requires_grad_()
    hessian, pred = vmap(
        jacfwd(jacrev(predict, argnums=0, has_aux=True), argnums=0, has_aux=True),
        in_dims=0,
    )(x_)  # [B, D2, D1, D1]
    if run_backward:
        l = hessian.sum()
        l.backward()
    return hessian, pred

def validate_result():
    # test functorch result
    ref_hes, ref_pred = reference_hessian()
    ft_hes, ft_pred = functorch_hessian()
    ref_hes = ref_hes.view_as(ft_hes)
    print(f"max pred       error: functorch: {(ref_pred - ft_pred).max():.2e}")
    print(f"max hessian    error: functorch: {(ref_hes - ft_hes).max():.2e}")

def benchmark(func):
    N = 20

    torch.cuda.synchronize()
    start = time.time()

    for i in range(N):
        torch.cuda.nvtx.range_push(func.__name__)
        _ = func()
        torch.cuda.nvtx.range_pop()

    torch.cuda.synchronize()
    time_ms = ((time.time() - start) / N) * 1000
    print(f"{func.__name__}: {time_ms:.3f} ms")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backward", default=False, action="store_true")
    args = parser.parse_args()
    if args.backward:
        run_backward = True
        print("===== benchmark with backward =====")
    else:
        print("===== benchmark without backward =====")

    validate_result()

    # warm up
    for i in range(10):
        reference_hessian()
        functorch_hessian()

    # benchmark hessian
    benchmark(reference_hessian)
    benchmark(functorch_hessian)
yueyericardo commented 1 year ago

ping @samdow @zou3519

zou3519 commented 1 year ago

Thanks for the report, we'll take a look soon

zou3519 commented 1 year ago

Bisected to https://github.com/pytorch/pytorch/pull/75195/files. That PR by itself may not be a problem; perhaps the problem is our batching rule for mv.

@yueyericardo is the repro you provided the entire model, or is it a subset of some model that you're running?

zou3519 commented 1 year ago

cc @lezcano @ezyang for https://github.com/pytorch/pytorch/pull/75195 -- this led to a performance regression in functorch. I'm not sure what the original intent of the PR is (there are no tests). I'm still trying to root cause this, but it is a bit difficult to visualize. What are the chances we could revert that PR?

(I've confirmed that reverting that single PR on pytorch/pytorch master makes the performance regression go away)

yueyericardo commented 1 year ago

Many thanks to @zou3519 for the quick debugging! I'm working on the NVIDIA Modulus project; we are using functorch because it provides significant speedups for the Jacobian and Hessian calculations. The minimal repro I provided is only a subset of our model, to demonstrate the performance regression.

Thanks again!

zou3519 commented 1 year ago

@yueyericardo - for the original model itself, is the performance regression also 25%, or is it a smaller number? Is the original model public? One thing we can do to prevent future regressions is to check the original model into https://github.com/pytorch/benchmark.

I've noticed a lot of other similar models where folks have an nn.Sequential made up of just nn.Linear and activations and need to compute vmap(jacrev(...)) or vmap(hessian(...)) of the quantity, so we could also potentially just check your script into torchbench otherwise.
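
For readers who want to see that pattern concretely, here is a minimal sketch of that kind of computation (the module and sizes are made up for illustration; the functorch 0.2.x API is assumed):

# Hypothetical minimal example of the pattern described above: an MLP of
# nn.Linear + activations, with a per-sample Hessian computed via
# vmap(hessian(...)). Sizes are illustrative only.
import torch
import torch.nn as nn
from functorch import vmap, hessian

net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))

def scalar_out(x):
    # hessian() expects a function of a single (unbatched) sample with scalar output
    return net(x).squeeze(-1)

x = torch.randn(128, 2)                          # batch of 128 samples
per_sample_hess = vmap(hessian(scalar_out))(x)   # shape [128, 2, 2]
print(per_sample_hess.shape)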

yueyericardo commented 1 year ago

Hi @zou3519, our source code is free to download from the website, but it is not developed on GitHub, and our code base might also be too large to put into pytorch/benchmark.

Yes, exactly! I believe the minimal repro I provided is enough to prevent future regressions for our model. Thanks!

lezcano commented 1 year ago

I think that rather than blindly reverting, we should get to the root of the problem, as it is very weird to get such a regression when dispatching from a more general function to a more specialized one (that was the reason for that PR).

Things that come to mind are:

If the answer to the above two is no, then this performance issue is likely on the functorch end and should be fixed. Otherwise, it's on the cuBLAS end and should be reported to NVIDIA.

cc @ngimel @xwang233 @ivanyashchuk

ezyang commented 1 year ago

@Lezcano It's fair to submit the upstream bugs, but if we know that our upstream library's general kernel has better perf than a specialized one, we might as well use it.

zou3519 commented 1 year ago

@Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?

Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

ezyang commented 1 year ago

No, I don't think so. The PR is supposed to make the kernel run faster.

lezcano commented 1 year ago

fwiw, I think this may be related to the open PR I have to avoid copies in matmul. Could you check whether https://github.com/pytorch/pytorch/pull/76828 fixes this?

In any case, I'm fine with reverting, but we should investigate what's causing this regardless

ngimel commented 1 year ago

@Lezcano @ezyang let's say that we did revert the PR (because we're trying to release PyTorch 1.12.1 as soon as possible). Would it cause any other problems?

Because the motivation was "dispatching from a more general function to a more concrete [function]", it sounds like this wouldn't change very much else.

That PR fixes spurious resize warnings that were previously generated, and by itself it is supposed to speed things up by avoiding squeeze/unsqueeze calls, which are not free (and especially not free when autograd is needed). As for the more general vs. more specialized function performance, we should investigate, but I doubt that's the case.

ngimel commented 1 year ago

@zou3519 can you by any chance collect profiling results for old and new versions?

zou3519 commented 1 year ago

In any case, I'm fine with reverting, but we should investigate what's causing this regardless

I agree this warrants more investigation. We've got a problem in that there is a timeline for 1.12.1, and I am not sure how long it is going to take to actually get to the bottom of this.

@zou3519 can you by any chance collect profiling results for old and new versions?

I can try but it's been a long time since I touched nvprof or nsight profiler, so I will need to relearn the magic invocations.

but we should investigate what's causing this regardless

Since we changed the mm to an mv, functorch generates different code for vmap(jacrev(jacfwd(mm))) as opposed to vmap(jacrev(jacfwd(mv))). It's plausible that the problem is that "functorch should generate better code"; we're still digging into it.

ngimel commented 1 year ago

Isn't 1.12.1 done already? No Nsight needed, the torch profiler should be enough (`with torch.profiler.profile() as p:`, then print the key averages and export a chrome trace at the end).
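
For anyone following along, a minimal sketch of that profiling recipe applied to the functorch_hessian function from the benchmark script above (the profiler options shown are just one reasonable choice):

# Minimal sketch of the torch.profiler run suggested above; assumes it runs in
# the context of the benchmark script, where functorch_hessian is defined.
import torch

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
) as p:
    for _ in range(20):
        functorch_hessian()

print(p.key_averages().table(sort_by="cuda_time_total", row_limit=20))
p.export_chrome_trace("functorch_hessian_trace.json")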

yueyericardo commented 1 year ago

FYI, the nsight profiling result before

Time    Total Time  Instances   Avg Med Min Max StdDev  Name
64.7%   928.624 ms  1000    928.623 μs  867.621 μs  297.345 μs  1.686 ms    493.406 μs  ampere_sgemm_128x64_nn
10.3%   147.304 ms  250 589.217 μs  589.187 μs  588.643 μs  590.435 μs  275 ns  ampere_sgemm_128x128_tn
9.4%    134.966 ms  1300    103.819 μs  70.432 μs   5.344 μs    194.209 μs  75.470 μs   void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, at::native::AddFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
6.8%    98.146 ms   900 109.050 μs  94.848 μs   70.081 μs   177.921 μs  43.651 μs   void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
5.4%    77.550 ms   350 221.572 μs  300.321 μs  18.080 μs   302.050 μs  124.914 μs  ampere_sgemm_128x64_tn
0.8%    10.981 ms   1450    7.572 μs    3.552 μs    2.880 μs    26.112 μs   7.973 μs    void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.7%    9.949 ms    150 66.324 μs   82.928 μs   27.968 μs   88.672 μs   26.863 μs   ampere_sgemm_32x32_sliced1x4_nn
0.6%    8.074 ms    300 26.912 μs   26.976 μs   25.696 μs   27.776 μs   390 ns  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.5%    7.713 ms    50  154.257 μs  154.177 μs  153.184 μs  155.905 μs  634 ns  ampere_sgemm_32x128_nn
0.3%    4.847 ms    100 48.468 μs   47.664 μs   34.688 μs   63.393 μs   13.397 μs   ampere_sgemm_32x32_sliced1x4_tn
0.2%    3.342 ms    650 5.141 μs    5.344 μs    4.256 μs    5.728 μs    423 ns  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 8)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
0.1%    2.028 ms    50  40.569 μs   40.512 μs   40.032 μs   41.536 μs   272 ns  ampere_sgemm_64x64_nn
0.1%    994.305 μs  50  19.886 μs   19.856 μs   19.649 μs   20.896 μs   169 ns  ampere_sgemm_128x32_nn
0.0%    435.235 μs  100 4.352 μs    4.384 μs    4.160 μs    4.576 μs    100 ns  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)

after (the third row 18.4% on a copy kernel)

Time    Total Time  Instances   Avg Med Min Max StdDev  Name
21.8%   433.409 ms  500 866.818 μs  866.756 μs  865.028 μs  870.116 μs  881 ns  ampere_sgemm_32x128_nt
21.2%   420.235 ms  250 1.681 ms    1.681 ms    1.678 ms    1.685 ms    813 ns  ampere_sgemm_128x64_nn
18.4%   365.797 ms  2850    128.349 μs  83.809 μs   4.096 μs    608.931 μs  173.023 μs  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::<unnamed>::direct_copy_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
10.8%   213.750 ms  300 712.501 μs  588.802 μs  588.195 μs  1.335 ms    277.174 μs  ampere_sgemm_128x128_tn
8.1%    160.586 ms  1300    123.527 μs  93.057 μs   5.089 μs    266.497 μs  92.181 μs   void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
7.9%    156.152 ms  500 312.304 μs  312.290 μs  310.145 μs  314.689 μs  781 ns  ampere_sgemm_128x32_nn
5.9%    117.864 ms  900 130.960 μs  129.824 μs  77.504 μs   203.681 μs  49.220 μs   void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::BinaryFunctor<float, float, float, void at::native::<unnamed>::threshold_kernel_impl<float>(at::TensorIteratorBase &, T1, T1)::[lambda(float, float) (instance 1)]>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
3.3%    66.223 ms   50  1.324 ms    1.324 ms    1.324 ms    1.326 ms    397 ns  ampere_sgemm_128x128_tt
1.1%    21.094 ms   1850    11.402 μs   3.584 μs    2.816 μs    65.857 μs   12.662 μs   void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::Array<char *, (int)1>>(int, T2, T3)
0.4%    8.071 ms    300 26.904 μs   26.976 μs   25.792 μs   28.128 μs   392 ns  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::launch_clamp_scalar(at::TensorIteratorBase &, c10::Scalar, c10::Scalar, at::native::detail::ClampLimits)::[lambda() (instance 1)]::operator ()() const::[lambda() (instance 14)]::operator ()() const::[lambda(float) (instance 1)], at::detail::Array<char *, (int)2>>(int, T2, T3)
0.4%    7.903 ms    50  158.063 μs  158.049 μs  156.609 μs  159.617 μs  634 ns  ampere_sgemm_32x128_nn
0.4%    7.359 ms    100 73.594 μs   73.600 μs   72.000 μs   78.368 μs   1.223 μs    ampere_sgemm_32x32_sliced1x4_nt
0.2%    3.242 ms    50  64.844 μs   64.928 μs   62.464 μs   65.761 μs   528 ns  ampere_sgemm_32x32_sliced1x4_tn
0.1%    2.255 ms    100 22.551 μs   22.688 μs   20.865 μs   24.192 μs   1.106 μs    void gemmSN_NN_kernel<float, (int)256, (int)4, (int)2, (int)8, (int)3, (int)4, (bool)0, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<const float>, cublasGemvTensorStridedBatched<float>>(cublasGemmSmallNParams<T9, T10, T11, T1>)
0.1%    2.043 ms    100 20.427 μs   20.416 μs   20.160 μs   22.656 μs   257 ns  ampere_sgemm_128x32_tn
0.0%    432.292 μs  100 4.322 μs    4.336 μs    4.128 μs    4.640 μs    97 ns   void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::FillFunctor<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)
lezcano commented 1 year ago

@ngimel the PR that fixed the warnings was already merged; this one is just concerned with avoiding copies. One of the cases where it elides a copy is when you multiply a matrix by a batch of matrices. This is exactly the batched version of the vector-matrix product that's causing the regression. That's why I think it may fix it.

ngimel commented 1 year ago

That must be coming from functorch, as that PR doesn't introduce any additional copies.

lezcano commented 1 year ago

yup, I think https://github.com/pytorch/pytorch/pull/76828 would fix the regression

zou3519 commented 1 year ago

I patched in https://github.com/pytorch/pytorch/pull/76828 and the above script ends up OOM-ing :/

zou3519 commented 1 year ago

Isn't 1.12.1 done already?

Not yet, we have a chance to change it (or request a 1.12.2 if necessary since this regression is large and these types of models are typical functorch usage)

lezcano commented 1 year ago

Re. the OOM: wow, that's certainly unexpected. I'm not sure what the best way to follow up on that is. Probably @ngimel has a better idea of how to proceed.

Regardless, landing that PR (the avoid-copies matmul one) will be tricky due to some discrepancies in the accuracy of mm vs mv that we found for float16. As such, I think the most realistic way forward is to revert the offending PR and investigate what's causing that OOM after 1.12.1 is released.
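
(For context, a rough sketch of the kind of mm-vs-mv float16 comparison being referred to; the shape and device are illustrative, and any observed discrepancy depends on the GPU and cuBLAS version:)

# Rough sketch comparing a matrix-vector product computed via mv and via mm
# in float16 on CUDA; shapes are illustrative, results are hardware-dependent.
import torch

A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
v = torch.randn(4096, device="cuda", dtype=torch.float16)

via_mv = torch.mv(A, v)
via_mm = torch.mm(A, v.unsqueeze(1)).squeeze(1)
ref = torch.mv(A.double(), v.double())           # float64 reference

print(f"mv error: {(via_mv.double() - ref).abs().max():.3e}")
print(f"mm error: {(via_mm.double() - ref).abs().max():.3e}")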

ngimel commented 1 year ago

the PR that fixed the warnings was already merged

I don't think that's true, #75195 itself is fixing a warning (otherwise a user-supplied correct 1d out was sent to mm, and mm complained about resizing it to 2d).

lezcano commented 1 year ago

I think the one that fixed the out version was https://github.com/pytorch/pytorch/pull/75197 and a previous one in that stack, but it may be the case that 75195 also fixed an out warning, I don't remember now.

zou3519 commented 1 year ago

Here is my attempt at a smaller repro that runs on a single version of pytorch. It computes a different quantity, vmap(jacrev(smaller_predict)), but also exhibits performance differences (1.3ms vs 7ms on my machine).

https://gist.github.com/zou3519/3869d460f8bcb12799967e08a5998d9c

In the gist there's also a trace (acquired via make_fx) of what the graph looks like, using the "old linear" (aka the implementation of matmul in 1.11) vs the "new linear" (the implementation of matmul in 1.12), if anyone's feeling ambitious about reading traces
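
For anyone who wants to acquire such a trace themselves, here is a minimal sketch on a toy function (assuming functorch 0.2.x, where make_fx is importable from the functorch package):

# Minimal sketch of acquiring a trace with make_fx on a toy function, to see
# which aten ops (mm / mv / bmm, copies, ...) vmap + jacrev decompose into.
import torch
from functorch import make_fx, vmap, jacrev

def f(x):
    return torch.mv(torch.eye(3), x)   # a single matrix-vector product

traced = make_fx(vmap(jacrev(f)))(torch.randn(5, 3))
print(traced.graph)                    # the lowered aten graph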

zou3519 commented 1 year ago

@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

yueyericardo commented 1 year ago

Hi @zou3519, thanks for following up! We are having some internal discussions regarding this and will come back to you tomorrow.

samdow commented 1 year ago

From looking at this a bit, I think what happened is:

During backwards,

We can also validate that this is our issue since we hit pre-regression performance numbers by changing functorch's mv batch rule here:

  auto other_ = moveBatchDimToFront(other, other_bdim);
  auto self_ = at::movedim(self, 0, 1);
  auto result = at::matmul(other_, self_);
  return std::make_tuple( std::move(result), 0 );

This doesn't trigger the copy, since the batch dimension for the saved relu activation is now the first dimension. However, this may hit other perf issues from transposing both self and result.
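
(For readers unfamiliar with the C++ batch-rule code above, here is a rough Python rendering of the rewrite being described; it is only meant to show the shape/math equivalence, not the actual implementation:)

# Rough Python rendering of the rewritten mv batch rule sketched above, for a
# batched `other` and unbatched `self`; shapes are illustrative.
import torch

B, N, M = 8, 5, 3
self_mat = torch.randn(N, M)      # the matrix argument of mv
other = torch.randn(B, M)         # batch of vectors, batch dim already in front

# reference: apply mv per batch element
ref = torch.stack([torch.mv(self_mat, other[b]) for b in range(B)])  # [B, N]

# the proposed batch rule: one matmul against the transposed matrix
out = torch.matmul(other, self_mat.movedim(0, 1))                    # [B, N]

print(torch.allclose(ref, out))   # True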


This is the smallest subset of @zou3519's trace where I can see the perf differences

import torch

mm_2 = torch.randn(60000, 512)
relu = torch.randn(10000, 512)
def old_linear_faster():
    mm_2_view = mm_2.view([10000, 6, 1, 512])
    mm_2_view_squeeze = mm_2_view.squeeze(2)

    relu_view = relu.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(mm_2_view_squeeze, relu_view, 0)
    return threshold_backward

bmm = torch.randn(10000, 512, 6)
other_relu = torch.randn(512, 10000)
def new_linear_faster():
    bmm_view = bmm.view([10000, 512, 6])
    bmm_view_permute = bmm_view.permute([0, 2, 1])

    other_relu_permute = other_relu.permute([1, 0])
    other_relu_permute_view = other_relu_permute.view([10000, 1, 512])

    threshold_backward = torch.ops.aten.threshold_backward(bmm_view_permute, other_relu_permute_view, 0)
    return threshold_backward

Notably, if I change the views on either bmm or other_relu to be contiguous constants, it has much faster performance. So it seems like threshold_backward doesn't copy if only one of its inputs is non-contiguous, but does if both are.
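
A quick sketch of how one might time the two snippets above (for a GPU measurement the tensors above would need to be created on CUDA; the helper below also runs on CPU):

# Simple wall-clock timing helper for the two snippets above; numbers will
# vary by device.
import time
import torch

def time_fn(fn, iters=100):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.time() - start) / iters * 1000  # ms per iteration

print(f"old_linear_faster: {time_fn(old_linear_faster):.3f} ms")
print(f"new_linear_faster: {time_fn(new_linear_faster):.3f} ms")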

ngimel commented 1 year ago

Thanks @samdow, the copy from relu definitely seems to affect perf; however, there's also another copy coming from MvBackward (see the attached profiler screenshot).

Also, why does threshold_backward on discontiguous inputs trigger a copy? In eager, threshold_backward should be able to handle them via TensorIterator.

lezcano commented 1 year ago

As discussed, I think that the copy_ coming from MvBackward would be fixed by https://github.com/pytorch/pytorch/pull/76828. Now, Richard reports in https://github.com/pytorch/functorch/issues/989#issuecomment-1199807110 that patching in this fix results in an OOM, so it's not clear what's going on there.

In my opinion, the path forward to fix this regression would be to:

  1. Figure out why it OOMs. I'm not sure how to tackle this, but perhaps @ngimel can help here.
  2. Unblock and land https://github.com/pytorch/pytorch/pull/76828. I will try to get a small repro tomorrow and pass it on to the NVIDIA folks for them to investigate further.
zou3519 commented 1 year ago

@yueyericardo I got access to the modulus repo -- could you point me to which model in the modulus repo contains the nn.Sequential above please? We're figuring out how to check the above code into torchbench and I'm looking for some more context -- is there a name for computing the hessian for the nn.Sequential? Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

@yueyericardo quick bump on the above. We're looking to merge some form of the above code into our benchmark suite and additional information would be very helpful

yueyericardo commented 1 year ago

@yueyericardo quick bump on the above. We're looking to merge some form of the above code into our benchmark suite and additional information would be very helpful

@zou3519 Sorry, I have already finished my internship there. I believe the next release of Modulus (this month) will include the functorch integration and will be available on GitLab. I'm including @NickGeneva from the Modulus team for further communication.

lezcano commented 1 year ago

For what it's worth, the fix https://github.com/pytorch/pytorch/pull/76828 is completely stalled on some errors with NestedTensor on some GPU architectures. If anyone would like to chime in on how to tackle it, you are more than welcome.

akshaysubr commented 1 year ago

@zou3519 The modulus version with functorch support was made public last week. Here is the repo.

could you point me to which model in the modulus repo contains the nn.Sequential above please?

The specific place where functorch is used is in this wrapper class.

Do you generally run .backward() after computing the hessian? What are some representative input sizes (are the sizes in the benchmark script representative?)

Yes, after computing the hessian, we need to run .backward() to actually get the weight gradients. Representative input sizes are typically O(1)-O(10), and the same for the outputs. The sizes in the benchmark script are representative of some of the simpler cases; they might go up from there, but the hidden dimensions are usually the same as in the benchmark.

IvanYashchuk commented 1 year ago

@akshaysubr the provided links do not work; maybe the project access settings are not set correctly.

akshaysubr commented 1 year ago

@IvanYashchuk Sorry, I should've mentioned that you have to apply for access to that repo here, and after that those links should work.

zou3519 commented 1 year ago

I was able to add a version of the original benchmark to pytorch/benchmark, so we can now prevent regressions in this model; I'm therefore lowering the priority of this issue. Leaving it open to discuss the changes to matmul above, though.

lezcano commented 1 year ago

fwiw, the matmul PR is as stalled as it was before. Neither Ivan nor I have been able to find the time to put together a standalone C repro for the cuBLAS team.

lezcano commented 1 year ago

After patching in the one-liner https://github.com/pytorch/pytorch/pull/75195 on top of the stack https://github.com/pytorch/pytorch/pull/76828, I still get an OOM using the script in the OP. The issue happens to be that, somehow, we end up with a tensor of shape [10000, 512, 512] and strides equal to zero within bmm_out_cuda. Then we try to materialize this tensor and we OOM. Interestingly enough, if you print the sizes/strides of the tensors that go through matmul, you get:

size t1:    [512, 512]
strides t1: [512, 1]
size t2:    [10000, 512, 3]
strides t2: [3, 30000, 1]

so the tensor that gets to bmm looks like a batched version of the one that gets to the last if-else within matmul. Perhaps this is a feature of functorch, but I don't really understand it. Note that this would be fine if this tensor, rather than having strides equal to zero, had strides equal to, say, [0, 512, 1], as we would not need to materialise it there. So my question is: is this tensor with strides equal to zero reasonable? Is it a bug / optimisation opportunity in functorch, or does it follow from some functorch feature and I'm doing something wrong here?
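
(To make the OOM concrete, here is a back-of-the-envelope sketch using an artificially constructed fully zero-strided tensor of the same shape: materialising it as contiguous memory needs roughly 10 GiB.)

# Back-of-the-envelope sketch: the memory cost of materialising a fully
# zero-strided [10000, 512, 512] float32 tensor like the one described above.
import torch

t = torch.zeros(1).expand(10000, 512, 512)     # strides (0, 0, 0), tiny storage
print(t.stride())                              # (0, 0, 0)

bytes_needed = t.numel() * t.element_size()    # what t.contiguous() would allocate
print(f"{bytes_needed / 2**30:.1f} GiB")       # ~9.8 GiB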

cc @kshitij12345