Open amirhajibabaei opened 2 years ago
Thanks for the feature request, @amirhajibabaei!
I think supporting sparse tensors could be tricky. Your example shows that we can't create a COO sparse tensor inside a function that is used within functorch ops like vjp. However, I also wonder how the sparse structure is used in your example: can't we just use the values (torch.stack(values)) for vjp, without the indices?
On the other hand, a function taking a sparse tensor as an argument does not work either:
import torch
import functorch
def foo(spt, x):
    return x + spt
i = [[0, 1, 1], [2, 0, 2]]
v = [3.0, 4.0, 5.0]
s = torch.sparse_coo_tensor(i, v, (2, 3))
x = torch.rand(2, 3)
print(foo(s, x))
_func = lambda st: foo(st, x)
result, vjp = functorch.vjp(_func, s)
gives
    result, vjp = functorch.vjp(_func, s)
  File "/ft/functorch/_src/eager_transforms.py", line 252, in vjp
    primals = _wrap_all_tensors(primals, level)
  File "/ft/functorch/_src/eager_transforms.py", line 77, in _wrap_all_tensors
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 179, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_pytree.py", line 179, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/ft/functorch/_src/eager_transforms.py", line 73, in _wrap_tensor_for_grad
    return _wrap_for_grad(maybe_tensor, level)
RuntimeError: sparse tensors do not have strides
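A minimal sketch of the "use the values without the indices" idea, assuming the sparsity pattern is fixed and only the nonzero values need gradients. The helper name foo_from_values is hypothetical and not part of functorch; it scatters the values into a dense tensor instead of building a COO tensor, so nothing sparse is ever wrapped by vjp:

```python
import torch

try:  # torch >= 2.0 ships the functorch transforms as torch.func
    from torch.func import vjp
except ImportError:  # older setups with the standalone functorch package
    from functorch import vjp


def foo_from_values(values, indices, shape, x):
    # Hypothetical dense stand-in for foo(spt, x): scatter the nonzero values
    # into a zero tensor. accumulate=True sums duplicate indices, which is
    # also how COO tensors treat them.
    dense = torch.zeros(shape, dtype=values.dtype).index_put(
        (indices[0], indices[1]), values, accumulate=True
    )
    return x + dense


indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
shape = (2, 3)
x = torch.rand(2, 3)

# Differentiate with respect to the dense values only; the indices are
# closed over and treated as constants.
out, vjp_fn = vjp(lambda v: foo_from_values(v, indices, shape, x), values)
(grad_values,) = vjp_fn(torch.ones_like(out))
```

This gives up the memory savings of a real sparse tensor, so it is only reasonable when the dense shape is small enough to materialize.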
Thank you @vfdev-5 for your prompt response.
I suspected that supporting sparse tensors would be tricky.
I can formulate the issue in another way: using a dict output instead of a sparse tensor.
I noticed that what I have in mind can already be implemented.
Therefore feel free to close the issue.
Although I suggest one simple tweak.
Consider the following function which returns a dict containing tensors:
import torch
import functorch
def function(x, species, param):
    unique = torch.unique(species)
    values = {}
    for u in unique:
        y = x[species == u]
        d = y.norm(dim=1)
        p = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
        values[int(u)] = p
    return values
# VJP:
param = torch.arange(0., 1., 0.2)
species = torch.randint(0, 2, (10,))
x = torch.rand(10, 3)
_func = lambda x: function(x, species, param)
result = _func(x)
result, vjp = functorch.vjp(_func, x)
The following vjp call already works (if the cotangents have the same pytree structure as the outputs):
cotangents_1 = {k: torch.ones(5) for k in result.keys()}
vjp(cotangents_1)
But the behavior that I want is that if the cotangents dict does not have the same keys as outputs, the result should be zero:
cotangents_2 = {100: torch.ones(5), 200: torch.ones(5)}
vjp(cotangents_2) # ----> expected output: zeros_like_x
# Error message:
RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {100: *, 200: *}, primal output: {0: *, 1: *}
For a simpler example, I expect the following
vjp({}) # ------> expected output: zeros_like_x
# Error message:
RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {}, primal output: {0: *, 1: *}
Or, at least:
vjp({0: None, 1: None}) # ------> expected output: zeros_like_x
# Error message:
~/anaconda3/envs/tensor/lib/python3.8/site-packages/torch/autograd/__init__.py in _make_grads(outputs, grads, is_grads_batched)
     65             if out.requires_grad:
     66                 if out.numel() != 1:
---> 67                     raise RuntimeError("grad can be implicitly created only for scalar outputs")
     68                 new_grads.append(torch.ones_like(out, memory_format=torch.preserve_format))
     69             else:
RuntimeError: grad can be implicitly created only for scalar outputs
The feature that I suggest is a new keyword, strict (default True), that controls how the pytree structure of the cotangents is matched against the outputs of the function. With strict=False, if the cotangent for a component of the output is not found, it is assumed to be zero.
This is not urgent, though, since the user can match the keys manually.
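Matching the keys manually could look roughly like the sketch below. The helper fill_missing_cotangents is hypothetical (it is not a functorch API): it emulates the proposed strict=False behavior for dict outputs only, by dropping unknown keys and substituting zeros for missing or None entries before calling the vjp function.

```python
import torch

try:  # torch >= 2.0 ships the functorch transforms as torch.func
    from torch.func import vjp
except ImportError:  # older setups with the standalone functorch package
    from functorch import vjp


def function(x, species, param):
    # Same function as in the thread: one tensor per unique species.
    values = {}
    for u in torch.unique(species):
        d = x[species == u].norm(dim=1)
        values[int(u)] = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
    return values


def fill_missing_cotangents(outputs, cotangents):
    # Hypothetical helper: rebuild the cotangents with exactly the keys (and
    # key order) of the outputs, using zeros where no cotangent was given.
    return {
        key: cotangents[key] if cotangents.get(key) is not None else torch.zeros_like(out)
        for key, out in outputs.items()
    }


torch.manual_seed(0)
param = torch.arange(0.0, 1.0, 0.2)
species = torch.randint(0, 2, (10,))
x = torch.rand(10, 3)
result, vjp_fn = vjp(lambda x: function(x, species, param), x)

# All three calls now pass the pytree check and yield a zero gradient.
(grad_unknown_keys,) = vjp_fn(fill_missing_cotangents(result, {100: torch.ones(5)}))
(grad_empty,) = vjp_fn(fill_missing_cotangents(result, {}))
(grad_none,) = vjp_fn(fill_missing_cotangents(result, {k: None for k in result}))
```

Note that this still allocates the zero tensors and runs the full backward pass, so it fixes the ergonomics but not the cost.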
@amirhajibabaei if I understand correctly, a dict output with vjp returning 0 for a nonexistent key is a way to represent sparse data, right? Looks interesting.
My concern about the strict argument is that it seems clearly defined for dicts, but its behaviour is unclear to me when the output type is different: list, tensor, etc. How do you think it should behave?
@vfdev-5 yes, that's right.
I do understand the complications of generalization. I think it is not unreasonable to ask the user to match the structure. But the user should at least have the option of passing None instead of a large chunk of zeros, to avoid the unnecessary computational cost.
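To illustrate the cost argument, here is a rough sketch using plain torch.autograd.grad rather than functorch. The helper vjp_skipping_none is hypothetical; it assumes dict outputs and a single input tensor, and simply leaves out of the backward pass every output whose cotangent is None or missing:

```python
import torch


def function(x, species, param):
    # Same function as in the thread: one tensor per unique species.
    values = {}
    for u in torch.unique(species):
        d = x[species == u].norm(dim=1)
        values[int(u)] = (d[:, None] - param).pow(2).mul(-0.5).exp().sum(dim=0)
    return values


def vjp_skipping_none(outputs, x, cotangents):
    # Hypothetical helper: only outputs with an actual cotangent are
    # backpropagated, so no zero tensors are allocated or pushed through.
    keys = [k for k in outputs if cotangents.get(k) is not None]
    if not keys:
        return torch.zeros_like(x)
    (grad,) = torch.autograd.grad(
        [outputs[k] for k in keys],
        x,
        [cotangents[k] for k in keys],
        retain_graph=True,  # keep the graph so the helper can be called again
    )
    return grad


torch.manual_seed(0)
param = torch.arange(0.0, 1.0, 0.2)
species = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
x = torch.rand(10, 3, requires_grad=True)
outputs = function(x, species, param)

grad_partial = vjp_skipping_none(outputs, x, {0: torch.ones(5), 1: None})
grad_nothing = vjp_skipping_none(outputs, x, {0: None, 1: None})
```

The result for the partial case should agree with passing explicit zeros for key 1; the difference is only in how much of the graph is traversed.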
If strict were to be implemented, I assume that the behavior with strict=False could be expected to resemble that of the zip function in Python 3.10: zip([1, 2], [1, 2, 3], strict=False)
Hello, thank you for a great repository. I was wondering whether it would be possible to add simple support for sparse tensors. An example is shown below:
Currently the following error is raised: