spcl / sten

Sparsity support for PyTorch

Function `SparseOperatorDispatcherBackward` returned an invalid gradient #1

Closed. ficstamas closed this issue 1 year ago

ficstamas commented 1 year ago

Description

I attempted to create some wrappers to simplify the use of sten and expand its functionality. I tried to work primarily with sparse tensors (which may be the source of the problem), but during the backward pass PyTorch tries to match the shape of the gradient against the dummy_shape, which fails for obvious reasons.

Packages

numpy==1.23.5
scipy==1.9.3
torch==1.12.1
sten==0.0.3 

Reproduction

It is straightforward to reproduce the error. Based on the SparseMLP example, I made a minimal "working" script. The only modification I made was to the out_fmt of sparse_op. Afterwards, I implemented every forward method, and only one of the many backward methods, because it fails before the next one is even called.

In SparseLinear.forward I just changed the torch.Tensor entries to a sparse tensor format:

sparse_op = sten.sparsified_op(
    orig_op=torch.nn.functional.linear,
    out_fmt=tuple(
        [
            (
                sten.RandomFractionSparsifier(self.weight_sparsity),
                sten.CscTensor,
                sten.RandomFractionSparsifier(self.weight_sparsity),
                sten.CscTensor,
            )
        ]
    ),
    grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)]),
)

The whole script:

import scipy.sparse
import sten
import torch
from torch.optim import AdamW
import random

@sten.register_sparsifier_implementation(sparsifer=sten.RandomFractionSparsifier, inp=sten.CscTensor, out=sten.CscTensor)
def torch_csc_to_csc_random_fraction(sparsifier, tensor, grad_fmt=None):
    return sten.SparseTensorWrapper.wrapped_from_dense(
        sten.CscTensor(
            scipy.sparse.csc_matrix(sten.random_mask_sparsify(tensor.wrapped_tensor.to_dense(), sparsifier.fraction))
        ),
        tensor,
        grad_fmt,
    )

@sten.register_bwd_op_impl(
    operator=torch.nn.functional.mse_loss,
    grad_out=(torch.Tensor,),
    grad_inp=(
        (sten.KeepAll, torch.Tensor),
        (sten.KeepAll, torch.Tensor),
    ),
    inp=(sten.CscTensor, torch.Tensor),
)
def sparse_to_sparse_backward_mse(ctx, grad_outputs, input_sparsifiers):
    input_1, input_2 = ctx.saved_tensors
    # I don't care about the right gradient atm, it fails due to incorrect shape
    return input_1.wrapped_tensor.to_dense(), input_2

@sten.register_fwd_op_impl(
    operator=torch.nn.functional.mse_loss,
    inp=(sten.CscTensor, torch.Tensor),
    out=[(sten.KeepAll, torch.Tensor)],
)
def fwd_mse_loss(ctx, inputs, output_sparsifiers):
    [
        inp,
        inp2,
    ] = inputs
    ctx.save_for_backward(inp, inp2)
    return torch.nn.functional.mse_loss(inp.wrapped_tensor.to_dense(), inp2)

@sten.register_fwd_op_impl(
    operator=torch.nn.functional.relu,
    inp=(sten.CscTensor,),
    out=[(sten.KeepAll, torch.Tensor)],
)
def fwd_relu_activation(ctx, inputs, output_sparsifiers):
    [
        inp,
    ] = inputs
    ctx.save_for_backward(inp)
    return torch.nn.functional.relu(inp.wrapped_tensor.to_dense())

@sten.register_fwd_op_impl(
    operator=torch.nn.functional.linear,
    inp=(torch.Tensor, sten.CscTensor, torch.Tensor),
    out=[(sten.RandomFractionSparsifier, sten.CscTensor)],
)
def torch_linear_fwd_impl(ctx, inputs, output_sparsifiers):
    input_, weights, bias = inputs
    ctx.save_for_backward(input_, weights, bias)
    output_ = torch.from_numpy(input_.numpy() @ weights.wrapped_tensor.data.T) + bias
    return sten.SparseTensorWrapper.wrapped_from_dense(
        sten.CscTensor(scipy.sparse.csc_matrix(output_.numpy())),
        output_,
        weights.grad_fmt,
    )

class SparseLinear(torch.nn.Module):
    def __init__(self, input_features, output_features, weight_sparsity):
        super().__init__()
        self.weight_sparsity = weight_sparsity
        self.weight = sten.SparseParameterWrapper(
            sten.random_fraction_sparsifier_dense_csc(
                sten.RandomFractionSparsifier(self.weight_sparsity),
                torch.randn(output_features, input_features),
                (
                    sten.KeepAll(),
                    torch.Tensor,
                    sten.RandomFractionSparsifier(self.weight_sparsity),
                    sten.CscTensor,
                ),
            )
        )
        self.bias = torch.nn.Parameter(torch.rand(output_features))

    def forward(self, input):
        sparse_op = sten.sparsified_op(
            orig_op=torch.nn.functional.linear,
            out_fmt=tuple(
                [
                    (
                        sten.RandomFractionSparsifier(self.weight_sparsity),
                        sten.CscTensor,
                        sten.RandomFractionSparsifier(self.weight_sparsity),
                        sten.CscTensor,
                    )
                ]
            ),
            grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)]),
        )
        return sparse_op(input, self.weight, self.bias)

class SparseMLP(torch.nn.Module):
    def __init__(self, channel_sizes, weight_sparsity):
        super().__init__()
        self.layers = torch.nn.Sequential()
        in_out_pairs = list(zip(channel_sizes[:-1], channel_sizes[1:]))
        for idx, (in_channels, out_channels) in enumerate(in_out_pairs):
            if idx != 0:
                self.layers.append(torch.nn.ReLU())
            self.layers.append(SparseLinear(in_channels, out_channels, weight_sparsity))

    def forward(self, input):
        return self.layers(input)

torch.random.manual_seed(0)
random.seed(0)

model = SparseMLP([50, 40, 30, 20, 30, 10], 0.8)
optimizer = AdamW(model.parameters(), lr=5e-5)
optimizer.zero_grad()
loss_fct = torch.nn.MSELoss()

target = torch.randn(15, 10)

output = model(torch.randn(15, 50))
loss = loss_fct(output, target)

loss.backward()

Output

Running the script gives the following output:

anaconda3\envs\sparsity\lib\site-packages\torch\distributed\distributed_c10d.py:181: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(
anaconda3\envs\sparsity\lib\site-packages\sten\sten.py:416: DispatchError: Semantics of torch.nn.functional.relu is unknown, trying to discover it by executing...
  warnings.warn(
anaconda3\envs\sparsity\lib\site-packages\sten\sten.py:416: DispatchError: Semantics of torch.nn.functional.mse_loss is unknown, trying to discover it by executing...
  warnings.warn(
Traceback (most recent call last):
  File "sparsity\demo.py", line 136, in <module>
    loss.backward()
  File "anaconda3\envs\sparsity\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "anaconda3\envs\sparsity\lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function SparseOperatorDispatcherBackward returned an invalid gradient at index 0 - got [15, 10] but expected shape compatible with [3, 1, 2, 1, 1, 1, 1, 1, 2, 1]

Process finished with exit code 1

In addition, I discovered that the dummy_shape is recorded on loss.grad_fn, which is consistent with the error being raised right after the backward implementation registered with register_bwd_op_impl for torch.nn.functional.mse_loss returns.

Also, the PyTorch part of the error message originates from here: https://github.com/pytorch/pytorch/blob/b95e1d76a86b7b66f0946f72ebd33889bfc19e03/torch/csrc/autograd/engine.cpp#L818 which calls metadata.incompatible_shape_error_message(i, grad): https://github.com/pytorch/pytorch/blob/77c2a8a11f7b5164c255b5b49dbc66a3f6533e9d/torch/csrc/autograd/input_metadata.h#L91
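
For reference, the same shape check can be triggered without sten by a plain torch.autograd.Function whose backward returns a gradient of the wrong shape (a minimal illustration; the class name BadShapeGrad is made up here):

import torch

class BadShapeGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately wrong shape: autograd compares it against the metadata
        # of the forward input (here [3, 4]) and raises the same
        # "returned an invalid gradient at index 0" RuntimeError.
        return torch.ones(5, 6)

x = torch.randn(3, 4, requires_grad=True)
BadShapeGrad.apply(x).sum().backward()  # RuntimeError from the shape check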

Do you think this is a problem with PyTorch, or is it something I did incorrectly with STen?

and-ivanov commented 1 year ago

Thank you for the detailed reproducing example.

At first glance, it looks like the issue is in `grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)])` inside the SparseLinear class. Basically, for each occurrence of a sparse tensor in the autograd graph, its gradient has to be wrapped inside SparseTensorWrapper. If this does not happen, autograd tries to operate on tensors of different types and breaks because of their shape mismatch. The easiest fix in your example is probably to replace the problematic line with `grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), sten.DenseTensor)])` and continue implementing the rest from there. DenseTensor as the last argument should automatically wrap torch.Tensor into SparseTensorWrapper while still keeping the same dense data.
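
Applied to the sparse_op call in SparseLinear.forward, the suggested change would look like this (a sketch; only the last element of grad_out_fmt changes):

sparse_op = sten.sparsified_op(
    orig_op=torch.nn.functional.linear,
    out_fmt=tuple(
        [
            (
                sten.RandomFractionSparsifier(self.weight_sparsity),
                sten.CscTensor,
                sten.RandomFractionSparsifier(self.weight_sparsity),
                sten.CscTensor,
            )
        ]
    ),
    # DenseTensor wraps the dense gradient in SparseTensorWrapper so that
    # autograd no longer compares a plain torch.Tensor against the dummy shape
    grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), sten.DenseTensor)]),
)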

I consider this API suboptimal and will try to investigate it further, especially taking into account that our examples provide exactly such a template.

ficstamas commented 1 year ago

Thanks for your quick response. I'll try to implement it in my actual code sometime early next week, but I believe it should work.