pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Broadcasting for torch.cross #39656

Open tachukao opened 4 years ago

tachukao commented 4 years ago

Cross products are useful in a number of situations (e.g., calculating quaternion products), and it would be great if we could broadcast with torch.cross.

cc @mruberry @rgommers @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi

mruberry commented 4 years ago

Our torch.cross behavior is actually inconsistent with numpy.cross at the moment, too. It'd be nice to get a PR making them the same, but we'll need to deprecate torch.cross's current behavior first.

akarshkumar0101 commented 4 years ago

Are there any updates on adding broadcasting support within torch.cross?

mruberry commented 4 years ago

Are there any updates on adding broadcasting support within torch.cross?

No. We do, however, plan a deprecation review for the next release (1.8), which I expect will include torch.cross.

svkatta commented 3 years ago

Are there any updates on adding broadcasting support within torch.cross? Still waiting for this.

Can we just broadcast the tensors using torch.broadcast_tensors(a, b) and then use torch.cross? Is this equivalent?
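A sketch of that suggested workaround: for inputs that already have the same number of dimensions, broadcasting first and then crossing matches what torch.linalg.cross computes. (As noted later in the thread, it is not fully equivalent in general, since broadcast_tensors would also silently expand a size-1 vector dimension that cross should reject.)

```python
import torch

# Broadcast-then-cross workaround, assuming both inputs already have
# the same number of dimensions and a size-3 last dimension.
a = torch.ones((2, 3))
b = torch.tensor([[0.0, 1.0, 0.0]])      # shape (1, 3)

a2, b2 = torch.broadcast_tensors(a, b)   # both now (2, 3)
out = torch.cross(a2, b2, dim=-1)
```

For this case the result agrees with torch.linalg.cross's built-in broadcasting.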

asmeurer commented 9 months ago

This is relevant for array API support. The 2022 version of the standard added broadcasting to cross https://data-apis.org/array-api/latest/extensions/generated/array_api.linalg.cross.html#array_api.linalg.cross.

PyTorch does now support broadcasting for cross:

>>> x1 = torch.ones((2, 3))
>>> x2 = torch.ones((1, 3))
>>> torch.linalg.cross(x1, x2)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

(and this is documented). But broadcasting doesn't work when it would add new dimensions:

>>> x1 = torch.ones((2, 3))
>>> x2 = torch.ones((3,))
>>> torch.linalg.cross(x1, x2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: linalg.cross: inputs must have the same number of dimensions.
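Until that case is supported, one workaround (a sketch, not an official recommendation) is to add the missing leading dimension by hand before calling linalg.cross:

```python
import torch

x1 = torch.ones((2, 3))
x2 = torch.tensor([0.0, 0.0, 1.0])       # shape (3,)

# unsqueeze makes x2 shape (1, 3); linalg.cross can then broadcast it
# against the (2, 3) input, since the number of dimensions now matches.
out = torch.linalg.cross(x1, x2.unsqueeze(0))
```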

Note that the logic isn't quite the same as calling broadcast_tensors(a, b). The cross-ed dimension should not be broadcast, so the logic should be more like

if not (a.shape[dim] == b.shape[dim] == 3):
    raise RuntimeError(...)
a, b = broadcast_tensors(a, b)
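Wrapped up as a runnable helper (the name broadcast_cross is hypothetical, not a PyTorch API), that logic might look like:

```python
import torch

def broadcast_cross(a, b, dim=-1):
    # Hypothetical helper: validate the cross-ed dimension *before*
    # broadcasting, so a size-1 vector dimension is rejected rather than
    # silently expanded to 3. Assumes a negative dim, so shape[dim]
    # indexes both inputs consistently even when their ndims differ.
    if not (a.shape[dim] == b.shape[dim] == 3):
        raise RuntimeError(
            "broadcast_cross: inputs must have size 3 in the cross dimension"
        )
    a, b = torch.broadcast_tensors(a, b)
    return torch.linalg.cross(a, b, dim=dim)
```

The default dim=-1 matches the array API convention for linalg.cross, which defines the cross product over the last axis.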