flaport / torch_sparse_solve

A sparse KLU solver for PyTorch.
https://pypi.org/project/torch-sparse-solve
GNU Lesser General Public License v2.1

fail to pass the `torch.autograd.gradcheck` #17

Open zhf-0 opened 6 months ago

zhf-0 commented 6 months ago

Thank you very much for providing such a good tool!

My problem is that when the input A is a 'real' sparse matrix, i.e. one built directly from indices and values rather than converted from a dense matrix, the torch.autograd.gradcheck() function throws an exception. The Python program I use is

import scipy.sparse.linalg
from scipy.sparse import coo_matrix
import torch
class SparseSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        '''
        A is a torch coo sparse matrix
        b is a tensor
        '''
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")

        A = A.coalesce()
        A_idx = A.indices().to('cpu').numpy()
        A_val = A.values().to('cpu').numpy()
        sci_A = coo_matrix((A_val,(A_idx[0,:],A_idx[1,:]) ),shape=A.shape)
        sci_A = sci_A.tocsr()

        np_b = b.detach().cpu().numpy()
        # Solve the sparse system
        if np_b.ndim == 1:
            np_x = scipy.sparse.linalg.spsolve(sci_A, np_b)
        else:
            factorisedsolver = scipy.sparse.linalg.factorized(sci_A)
            np_x = factorisedsolver(np_b)

        x = torch.as_tensor(np_x)
        # Not sure if the following is needed / helpful
        if A.requires_grad or b.requires_grad:
            x.requires_grad = True

        # Save context for backward pass
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        # Recover context
        A, b, x = ctx.saved_tensors

        # Compute gradient with respect to b
        gradb = SparseSolve.apply(A.t(), grad)

        gradAidx = A.indices()
        mgradbselect = -gradb.index_select(0,gradAidx[0,:])
        xselect = x.index_select(0,gradAidx[1,:])
        mgbx = mgradbselect * xselect
        if x.dim() == 1:
            gradAvals = mgbx
        else:
            gradAvals = torch.sum( mgbx, dim=1 )
        gradA = torch.sparse_coo_tensor(gradAidx, gradAvals, A.shape)
        return gradA, gradb

sparsesolve = SparseSolve.apply

row_vec = torch.tensor([0, 0, 1, 2])
col_vec = torch.tensor([0, 2, 1, 2])
val_vec = torch.tensor([3.0, 4.0, 5.0, 6.0],dtype=torch.float64)
A = torch.sparse_coo_tensor(torch.stack((row_vec,col_vec),0), val_vec, (3, 3))
b = torch.ones(3, dtype=torch.float64, requires_grad=False)
A.requires_grad=True
b.requires_grad=True
res = torch.autograd.gradcheck(sparsesolve, [A, b], raise_exception=True)
print(res)

which is based on the program from "Differentiable sparse linear solver with cupy backend - 'unsupported tensor layout: Sparse' in gradcheck", which its author @tvercaut wrote based on your blog and code. I modified the program and limited it to running only on the CPU.

The output is

torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-3.7037e-02,  0.0000e+00, -1.3878e-11],
        [-6.6667e-02,  0.0000e+00,  0.0000e+00],
        [-5.5556e-02,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.2222e-02,  0.0000e+00],
        [ 0.0000e+00, -4.0000e-02,  0.0000e+00],
        [ 0.0000e+00, -3.3333e-02,  0.0000e+00],
        [ 2.4691e-02,  0.0000e+00, -1.8519e-02],
        [ 4.4444e-02,  0.0000e+00, -3.3333e-02],
        [ 3.7037e-02,  0.0000e+00, -2.7778e-02]], dtype=torch.float64)
analytical:tensor([[-0.0370,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0556,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0400,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0370,  0.0000, -0.0278]], dtype=torch.float64)

Your and @tvercaut's programs pass the gradient check only because the sparse matrix A used there is actually converted from a dense matrix. Consequently, autograd computes a gradient for every element.

The derivative formula from your blog is $$\frac{\partial L}{\partial A} = - \frac{\partial L}{\partial b} \otimes x$$ Since the matrix A is sparse, $\frac{\partial L}{\partial A_{ij}}$ should be $0$ whenever $A_{ij}=0$, but the results computed by PyTorch show this is not the case. If I change the backward() function into

@staticmethod
def backward(ctx, grad):
    A, b, x = ctx.saved_tensors
    gradb = SparseSolve.apply(A.t(), grad)

    # Dense outer-product gradient instead of the sparse one
    if x.ndim == 1:
        gradA = -gradb.reshape(-1, 1) @ x.reshape(1, -1)
    else:
        gradA = -gradb @ x.T
    return gradA, gradb

Then the gradient check passes. However, gradA is now a dense matrix, which is not consistent with the theoretical result. There is a similar issue #13 but without a detailed explanation. So I want to ask: which gradient is right, the sparse one or the dense one?
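
For concreteness, here is a small self-contained check I sketched against the 3x3 example above (not part of the original programs), showing that the sparse gradient is simply the dense outer product restricted to the stored positions:

import torch

# Rebuild the 3x3 example from above.
row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 2])
val = torch.tensor([3.0, 4.0, 5.0, 6.0], dtype=torch.float64)
A = torch.sparse_coo_tensor(torch.stack((row, col), 0), val, (3, 3))
b = torch.ones(3, dtype=torch.float64)

Ad = A.to_dense()
x = torch.linalg.solve(Ad, b)            # forward solution
grad = torch.ones_like(x)                # pretend dL/dx = 1
gradb = torch.linalg.solve(Ad.T, grad)   # dL/db = A^{-T} dL/dx

dense_gradA = -gradb.reshape(-1, 1) @ x.reshape(1, -1)  # dense outer product
sparse_vals = -gradb[row] * x[col]                      # restricted to nnz

print(dense_gradA[row, col])  # same numbers as sparse_vals
print(sparse_vals)

So the two candidates agree wherever A actually stores a value; they differ only in whether the structurally zero positions receive a nonzero gradient at all.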

tvercaut commented 6 months ago

As mentioned in #13, in most practical cases involving sparse matrices, you would want the derivative with respect to the non-zero elements only.
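
In other words, restricting the formula quoted above to the sparsity pattern gives $$\left.\frac{\partial L}{\partial A_{ij}}\right|_{(i,j)\in\operatorname{nnz}(A)} = -\left(\frac{\partial L}{\partial b}\right)_i x_j$$ which is exactly what the gradAvals computation in your backward() produces.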

There is some further discussion here: https://github.com/pytorch/pytorch/issues/87448

I haven't checked in a while, but at the time PyTorch had some issues with gradcheck and sparse matrix operations: https://github.com/pytorch/pytorch/issues/87085. In any case, you probably need to pass check_sparse_nnz=True to gradcheck if you are indeed using the standard sparse ops semantics.

Finally, if you want, we have consolidated pure Python functions like the ones above here: https://github.com/cai4cai/torchsparsegradutils

zhf-0 commented 6 months ago

Thank you very much for your reply.

So, based on the content of those links, SparseSolve fails the gradient check because the torch.autograd.gradcheck function can't handle sparse matrices properly. Am I right?

The version of PyTorch I'm using is 2.3.0, which has deprecated the check_sparse_nnz=True parameter. Following pytorch/pytorch#87085, I use masked=True to check the gradient instead, but the output is

check SparseSolve grad
Traceback (most recent call last):
  File "/home/project/ai4solver/opt_p/sparseLU.py", line 253, in <module>
    CheckGrad(Aref,bref)
  File "/home/project/ai4solver/opt_p/sparseLU.py", line 233, in CheckGrad
    res = torch.autograd.gradcheck(sparsesolve, [A, b], masked=True, raise_exception=True)
  File "/home/software/miniconda/install/envs/sparse/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2049, in gradcheck
    return _gradcheck_helper(**args)
  File "/home/software/miniconda/install/envs/sparse/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 2108, in _gradcheck_helper
    _test_undefined_backward_mode(func, outputs, tupled_inputs)
  File "/home/software/miniconda/install/envs/sparse/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1353, in _test_undefined_backward_mode
    return all(check_undefined_grad_support(output) for output in outputs_to_check)
  File "/home/software/miniconda/install/envs/sparse/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1353, in <genexpr>
    return all(check_undefined_grad_support(output) for output in outputs_to_check)
  File "/home/software/miniconda/install/envs/sparse/lib/python3.10/site-packages/torch/autograd/gradcheck.py", line 1321, in check_undefined_grad_support
    if (gi is not None) and (not gi.eq(0).all()):
NotImplementedError: Could not run 'aten::eq.Scalar' with arguments from the 'SparseCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::eq.Scalar' is only available for these backends: [CPU, CUDA, HIP, MPS, IPU, XPU, HPU, VE, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, Meta, FPGA, ORT, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, QuantizedMeta, CustomRNGKeyId, MkldnnCPU, SparseCsrCPU, SparseCsrCUDA, SparseCsrHIP, SparseCsrMPS, SparseCsrIPU, SparseCsrXPU, SparseCsrHPU, SparseCsrVE, SparseCsrMTIA, SparseCsrPrivateUse1, SparseCsrPrivateUse2, SparseCsrPrivateUse3, SparseCsrMeta, NestedTensorCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Undefined: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
CPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCPU.cpp:31419 [kernel]
CUDA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
HIP: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MPS: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
IPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
XPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
HPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
VE: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MTIA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse1: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse2: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
PrivateUse3: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Meta: registered at /dev/null:241 [kernel]
FPGA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
ORT: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Vulkan: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
Metal: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedCPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterQuantizedCPU.cpp:951 [kernel]
QuantizedCUDA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedHIP: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMPS: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedIPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedXPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedHPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedVE: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMTIA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
QuantizedMeta: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
CustomRNGKeyId: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
MkldnnCPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrCPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrCUDA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrHIP: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMPS: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrIPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrXPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrHPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrVE: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMTIA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
SparseCsrMeta: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp:21612 [default backend kernel]
NestedTensorCPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/build/aten/src/ATen/RegisterNestedTensorCPU.cpp:775 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/DynamicLayer.cpp:497 [backend fallback]
Functionalize: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
Conjugate: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradVE: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradMTIA: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/VariableType_0.cpp:17438 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/torch/csrc/autograd/generated/TraceType_0.cpp:16910 [kernel]
AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp:320 [kernel]
BatchedNestedTensor: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/LegacyBatchingRegistrations.cpp:1079 [kernel]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/TensorWrapper.cpp:202 [backend fallback]
PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/functorch/DynamicLayer.cpp:493 [backend fallback]
PreDispatch: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1712608958871/work/aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]

which is a total mess. I wonder whether there is a way to check the gradient with respect to a sparse matrix.
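
For what it's worth, one way to sidestep gradcheck entirely is a manual central-difference check over just the stored values. Below is a minimal sketch of that idea (my own workaround, not from the discussion above), using scipy on the same 3x3 example and L = sum(x) as the scalar loss:

import numpy as np
import scipy.sparse
import scipy.sparse.linalg

row = np.array([0, 0, 1, 2])
col = np.array([0, 2, 1, 2])
val = np.array([3.0, 4.0, 5.0, 6.0])
b = np.ones(3)

def make_A(vals):
    return scipy.sparse.coo_matrix((vals, (row, col)), shape=(3, 3)).tocsr()

def loss(vals):
    return scipy.sparse.linalg.spsolve(make_A(vals), b).sum()

# Analytical gradient w.r.t. the stored values only:
# dL/dA_ij = -(A^{-T} dL/dx)_i * x_j at (row, col), with dL/dx = 1.
x = scipy.sparse.linalg.spsolve(make_A(val), b)
gradb = scipy.sparse.linalg.spsolve(make_A(val).T.tocsr(), np.ones(3))
analytical = -gradb[row] * x[col]

# Numerical gradient: perturb each stored value in turn.
eps = 1e-6
numerical = np.empty_like(val)
for k in range(len(val)):
    vp, vm = val.copy(), val.copy()
    vp[k] += eps
    vm[k] -= eps
    numerical[k] = (loss(vp) - loss(vm)) / (2 * eps)

print(np.allclose(analytical, numerical))  # should print True

This checks exactly the quantity the sparse backward() above returns for the stored values, without ever handing a sparse tensor to gradcheck.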

Finally, if you want, we have consolidated pure Python functions like the ones above here: https://github.com/cai4cai/torchsparsegradutils

Well, thank you for your advice. I will try the code in that repository.