pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

support for sparse_coo tensors #655

Open amirhajibabaei opened 2 years ago

amirhajibabaei commented 2 years ago

Hello, thank you for a great repository. I was wondering whether it is possible to add simple support for sparse tensors. Example code is shown below:

import torch
import functorch

def function(x, species, param):
    unique = torch.unique(species)
    values = []
    for u in unique:
        y = x[species == u]
        d = y.norm(dim=1)
        p = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
        values.append(p)
    q = torch.sparse_coo_tensor([unique.tolist()], torch.stack(values))
    return q

param = torch.arange(0., 1., 0.2)
species = torch.randint(0, 2, (10,))
x = torch.rand(10, 3)
_func = lambda x: function(x, species, param)
#result = _func(x)
result, vjp = functorch.vjp(_func, x)

Currently the following error is raised:

     11         p = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
     12         values.append(p)
---> 13     q = torch.sparse_coo_tensor([unique.tolist()], torch.stack(values))
     14     return q
     15 

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

vfdev-5 commented 2 years ago

Thanks for the feature request, @amirhajibabaei!

I think supporting sparse tensors could be tricky. Your example shows that we can't create a COO sparse tensor inside a function that is passed to functorch ops like vjp. However, I also wonder how the sparse structure is used in your example: can't we just use the values (torch.stack(values)) for vjp, without the indices?
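A minimal sketch of that suggestion, assuming a PyTorch release where functorch's transforms are exposed as torch.func (on older installs, functorch.vjp would take its place). The name function_dense is hypothetical; it returns only the stacked values and leaves the indices (unique) outside the transformed function:

```python
import torch
from torch.func import vjp  # assumption: PyTorch >= 2.0; older setups use functorch.vjp


def function_dense(x, species, param):
    # Same computation as in the report, but the output is a dense
    # (n_unique, n_param) tensor instead of a sparse COO tensor.
    values = []
    for u in torch.unique(species):
        y = x[species == u]
        d = y.norm(dim=1)
        p = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
        values.append(p)
    return torch.stack(values)


torch.manual_seed(0)
param = torch.arange(0., 1., 0.2)
species = torch.randint(0, 2, (10,))
x = torch.rand(10, 3)

# The row indices are computed outside the transform.
unique = torch.unique(species)

dense, vjp_fn = vjp(lambda x: function_dense(x, species, param), x)
(grad_x,) = vjp_fn(torch.ones_like(dense))
```

Row k of dense corresponds to species unique[k], so a sparse tensor can still be assembled afterwards, outside the transform, if one is needed.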

On the other hand, a function that takes a sparse tensor as an argument does not work either:

import torch
import functorch

def foo(spt, x):
    return x + spt

i = [[0, 1, 1], [2, 0, 2]]
v = [3.0, 4.0, 5.0]
s = torch.sparse_coo_tensor(i, v, (2, 3))
x = torch.rand(2, 3)

print(foo(s, x))

_func = lambda st: foo(st, x)
result, vjp = functorch.vjp(_func, s)

gives

    result, vjp = functorch.vjp(_func, s)
  File "/ft/functorch/_src/eager_transforms.py", line 252, in vjp
    primals = _wrap_all_tensors(primals, level)
  File "/ft/functorch/_src/eager_transforms.py", line 77, in _wrap_all_tensors
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 179, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 179, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/ft/functorch/_src/eager_transforms.py", line 73, in _wrap_tensor_for_grad
    return _wrap_for_grad(maybe_tensor, level)
RuntimeError: sparse tensors do not have strides

amirhajibabaei commented 2 years ago

Thank you, @vfdev-5, for your prompt response. I suspected that supporting sparse tensors would be tricky. I can formulate the issue another way, using a dict output instead of a sparse tensor. I noticed that what I have in mind can already be implemented, so feel free to close the issue. I do suggest one simple tweak, though. Consider the following function, which returns a dict of tensors:

import torch
import functorch

def function(x, species, param):
    unique = torch.unique(species)
    values = {}
    for u in unique:
        y = x[species == u]
        d = y.norm(dim=1)
        p = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
        values[int(u)] = p
    return values

# VJP:
param = torch.arange(0., 1., 0.2)
species = torch.randint(0, 2, (10,))
x = torch.rand(10, 3)
_func = lambda x: function(x, species, param)
result = _func(x)
result, vjp = functorch.vjp(_func, x)

The following vjp already works (if the cotangents have the same pytree structure as the outputs):

cotangents_1 = {k: torch.ones(5) for k in result.keys()}
vjp(cotangents_1)

But the behavior that I want is that if the cotangents dict does not have the same keys as outputs, the result should be zero:

cotangents_2 = {100: torch.ones(5), 200: torch.ones(5)}
vjp(cotangents_2)   # ---->  expected output:  zeros_like_x

# Error message:
RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {100: *, 200: *}, primal output: {0: *, 1: *}

For a simpler example, I expect the following:

vjp({})    # ------>  expected output:  zeros_like_x

# Error message:
RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {}, primal output: {0: *, 1: *}

Or, at least:

vjp({0: None, 1: None})     # ------>  expected output:  zeros_like_x

# Error message:
~/anaconda3/envs/tensor/lib/python3.8/site-packages/torch/autograd/__init__.py in _make_grads(outputs, grads, is_grads_batched)
     65             if out.requires_grad:
     66                 if out.numel() != 1:
---> 67                     raise RuntimeError("grad can be implicitly created only for scalar outputs")
     68                 new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format))
     69             else:

RuntimeError: grad can be implicitly created only for scalar outputs

The feature I suggest is a new keyword, strict (default True), that controls whether the pytree structure of the cotangents must match that of the function's outputs. With strict=False, a missing cotangent for any component of the output is assumed to be zero. This is not urgent, since the user can match the keys manually.
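Until something like that exists, the manual matching can be sketched in user code. The helper below is hypothetical (lenient_vjp is not part of functorch or torch.func) and only handles dict outputs: missing keys and None entries are replaced by zeros, and keys that are not present in the outputs are ignored. It assumes torch.func.vjp is available (functorch.vjp on older installs), and the function f is just a small stand-in for the example above:

```python
import torch
from torch.func import vjp  # assumption: PyTorch >= 2.0; older setups use functorch.vjp


def lenient_vjp(vjp_fn, outputs, cotangents):
    # Hypothetical helper emulating the proposed strict=False for dict outputs.
    # Missing or None cotangents become zeros; extra keys are dropped.
    filled = {}
    for key, out in outputs.items():
        ct = cotangents.get(key)
        filled[key] = torch.zeros_like(out) if ct is None else ct
    return vjp_fn(filled)


def f(x):
    # Stand-in function returning a dict of tensors.
    return {0: x[:2].sin(), 1: x[2:].cos()}


x = torch.rand(4)
outputs, vjp_fn = vjp(f, x)

(g_empty,) = lenient_vjp(vjp_fn, outputs, {})
(g_other,) = lenient_vjp(vjp_fn, outputs, {100: torch.ones(2), 200: torch.ones(2)})
(g_none,) = lenient_vjp(vjp_fn, outputs, {0: None, 1: None})
(g_partial,) = lenient_vjp(vjp_fn, outputs, {0: torch.ones(2)})
```

Note that this still materializes the zero tensors, so it reproduces the requested behavior but not the computational saving that a built-in option could offer.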

vfdev-5 commented 2 years ago

@amirhajibabaei if I understand correctly, a dict output with vjp returning 0 for a nonexistent key is a way to represent sparse data, right? Looks interesting.

My concern about the strict argument is that it seems clearly defined for dicts, but its behaviour is unclear to me for other output types: list, tensor, etc. How do you think it should behave?

amirhajibabaei commented 2 years ago

@vfdev-5 yes, that's right.
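To make the correspondence concrete, here is a small sketch of how such a dict could be turned into a sparse COO tensor after the fact, outside of any functorch transform. The helper name dict_to_sparse_coo is hypothetical, and it assumes integer keys and values of equal shape:

```python
import torch


def dict_to_sparse_coo(values):
    # Hypothetical helper: keys become indices along the first (sparse)
    # dimension, and the stacked tensors become the dense value rows.
    keys = sorted(values)
    indices = torch.tensor([keys])  # shape (1, nnz)
    stacked = torch.stack([values[k] for k in keys])  # shape (nnz, ...)
    return torch.sparse_coo_tensor(indices, stacked)


values = {3: torch.full((5,), 2.0), 0: torch.ones(5)}
sp = dict_to_sparse_coo(values)
dense = sp.to_dense()  # rows 1 and 2 are implicit zeros
```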

I do understand the complications of generalizing. I think it is not unreasonable to ask the user to match the structure. But the user should at least have the option of passing None instead of a large chunk of zeros, to avoid the unnecessary computational cost.

If strict were to be implemented, I assume the following behaviors could be expected with strict=False: