pytorch / nestedtensor

[Prototype] Tools for the concurrent manipulation of variably sized Tensors.
BSD 3-Clause "New" or "Revised" License

slice operation on "ragged" dimension #473

Open wenleix opened 2 years ago

wenleix commented 2 years ago

🚀 Feature

Support slice operation on "ragged" dimension

Motivation

In preprocessing we often want to operate over variable-width lists, such as token IDs in the text domain or sparse features in the recommendation domain; one common operation is to slice each list (e.g. keep only the first k elements). One way is to use Arrow's List type:

>>> import torcharrow as ta
>>> id_list = ta.column([[0, 1, 2, 3], [4, 5, 6, 7, 8], [8, 9]])
>>> id_list
0  [0, 1, 2, 3]
1  [4, 5, 6, 7, 8]
2  [8, 9]
dtype: List(int64), length: 3, null_count: 0

>>> id_list.list.slice(stop=3)
0  [0, 1, 2]
1  [4, 5, 6]
2  [8, 9]
dtype: List(Int64(nullable=True)), length: 3, null_count: 0

I was thinking nested tensors might also work well for this use case (especially when doing preprocessing after Tensor collate). But it looks like slice is not yet supported on the ragged dimension:

>>> import torch
>>> a, b, c = torch.arange(4), torch.arange(5) + 4, torch.arange(2) + 8
>>> id_list = torch.nested_tensor([a, b, c])
>>> id_list
nested_tensor([
  tensor([0, 1, 2, 3]),
  tensor([4, 5, 6, 7, 8]),
  tensor([8, 9])
])
>>> id_list[:, :3]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor

Wondering if there is any plan to support this? Thanks!
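In the meantime, one way to emulate `id_list[:, :3]` is to unbind the nested tensor into its constituent tensors, slice each one, and rebuild. This is a workaround sketch, not the requested native op; it assumes the `torch.nested.nested_tensor` constructor from recent PyTorch (the prototype repo exposes `torch.nested_tensor` directly), and it copies data rather than returning a view:

```python
import torch

# Build the same ragged data as above
a, b, c = torch.arange(4), torch.arange(5) + 4, torch.arange(2) + 8
nt = torch.nested.nested_tensor([a, b, c])

# Workaround for nt[:, :3]: slice each constituent tensor, then rebuild.
# unbind() returns the underlying tensors, which support ordinary slicing.
sliced = torch.nested.nested_tensor([t[:3] for t in nt.unbind()])
```

This materializes a new nested tensor per slice, so it loses whatever zero-copy behavior a native ragged `slice` could provide, but it reproduces the Arrow `list.slice(stop=3)` semantics (rows shorter than 3 are kept whole).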

dracifer commented 2 years ago

It looks like slicing with a fixed first dimension works (e.g. `id_list[0][:3]`). `torch.topk` does not work on the ragged dimension either.
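To spell out that observation: indexing a fixed position along the batch dimension yields a regular tensor, so ordinary slicing and `torch.topk` both work per row. The sketch below assumes the `torch.nested.nested_tensor` constructor from recent PyTorch, and the `k=min(2, t.numel())` clamp is just an illustrative guard for rows shorter than k:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.arange(4), torch.arange(5) + 4, torch.arange(2) + 8]
)

# Fixing the first dimension returns a regular tensor, which slices fine
first_row_slice = nt[0][:3]

# Per-row topk as a stand-in for a ragged-dimension torch.topk;
# clamp k so rows shorter than 2 elements don't raise
top2_per_row = [torch.topk(t, k=min(2, t.numel())).values for t in nt.unbind()]
```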