tachukao opened this issue 4 years ago
Our torch.cross behavior is actually inconsistent with numpy.cross at the moment, too. It'd be nice to get a PR making them the same, but we'll need to deprecate torch.cross's current behavior first.
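For the record, if I recall the defaults correctly, one concrete instance of that inconsistency: numpy.cross crosses along the last axis by default, while torch.cross, when dim is not given, uses the first dimension it finds with size 3. A minimal sketch:

import numpy as np
import torch

a = np.arange(9.0).reshape(3, 3)
b = np.ones((3, 3))

np.cross(a, b)  # crosses along the last axis (axis=-1)

# torch.cross without an explicit dim picks the first dimension of size 3,
# which here is dim 0, so the result differs from numpy's
torch.cross(torch.from_numpy(a), torch.from_numpy(b))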
Are there any updates on adding broadcasting support within torch.cross?
No. We do plan a deprecation review for the upcoming 1.8 release, however, which I expect will include torch.cross.
Are there any updates on adding broadcasting support within torch.cross? Still waiting for this.
You can just broadcast the tensors using torch.broadcast_tensors(a, b) and then use torch.cross. Is this equivalent?
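For concreteness, a minimal sketch of that workaround (as a later comment notes, it is not exactly equivalent, because the cross dimension itself should not be broadcast):

import torch

a = torch.ones((2, 3))
b = torch.ones((3,))

# broadcast both inputs to a common shape, then take the cross product
a2, b2 = torch.broadcast_tensors(a, b)
out = torch.cross(a2, b2, dim=-1)  # shape (2, 3)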
This is relevant for array API support. The 2022 version of the standard added broadcasting to cross: https://data-apis.org/array-api/latest/extensions/generated/array_api.linalg.cross.html#array_api.linalg.cross.
PyTorch does now support broadcasting for cross:
>>> x1 = torch.ones((2, 3))
>>> x2 = torch.ones((1, 3))
>>> torch.linalg.cross(x1, x2)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
(and this is documented). But broadcasting doesn't work when it would add new dimensions:
>>> x1 = torch.ones((2, 3))
>>> x2 = torch.ones((3,))
>>> torch.linalg.cross(x1, x2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: linalg.cross: inputs must have the same number of dimensions.
Note that the logic isn't quite the same as calling broadcast_tensors(a, b). The cross-ed dimension should not be broadcast, so the logic should be more like
if not (a.shape[dim] == b.shape[dim] == 3):
    raise RuntimeError(...)
a, b = broadcast_tensors(a, b)
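Spelled out as a runnable sketch (the function name and error message below are mine, not an actual PyTorch API):

import torch

def broadcasting_cross(a, b, dim=-1):
    # The cross dimension must already have size 3 in both inputs;
    # a size-1 dimension must not be expanded to 3 by broadcasting.
    if not (a.shape[dim] == b.shape[dim] == 3):
        raise RuntimeError("cross: inputs must have size 3 along `dim`")
    # All remaining dimensions broadcast as usual.
    a, b = torch.broadcast_tensors(a, b)
    return torch.linalg.cross(a, b, dim=dim)

With this, the failing (2, 3) x (3,) example above returns a (2, 3) result.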
Cross products are useful in a number of situations (e.g., calculations of quaternion products), and it would be great if we could broadcast with torch.cross.
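To make the quaternion use case concrete, here is a sketch of a batched Hamilton product (the (..., 4) layout with the scalar part first is my assumption); the broadcast_tensors line is exactly the boilerplate that native broadcasting in cross would remove:

import torch

def quat_mul(p, q):
    # Hamilton product of quaternions stored as (..., 4), scalar part first.
    pw, pv = p[..., :1], p[..., 1:]
    qw, qv = q[..., :1], q[..., 1:]
    # Workaround: linalg.cross needs matching ndims, so broadcast by hand.
    pv, qv = torch.broadcast_tensors(pv, qv)
    w = pw * qw - (pv * qv).sum(-1, keepdim=True)
    v = pw * qv + qw * pv + torch.linalg.cross(pv, qv, dim=-1)
    return torch.cat([w, v], dim=-1)

# e.g. multiply a whole batch by a single fixed quaternion:
# quat_mul(torch.randn(100, 4), torch.tensor([1., 0., 0., 0.]))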
cc @mruberry @rgommers @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi