Closed: elientumba2019 closed this issue 1 year ago
This shouldn't be difficult to do; we can prioritize it.
This is also true for the other in-place bitwise ops. What concerns me is why `test_op_has_batch_rule` does not catch it 🤔
Repro:
```python
import torch
import functorch

for bitwise_op in ['bitwise_xor_', 'bitwise_or_', 'bitwise_and_', 'bitwise_not_',
                   'bitwise_left_shift_', 'bitwise_right_shift_']:
    x = torch.zeros(3, 3, dtype=torch.long)

    def fn(x):
        # Look up the in-place bitwise method by name and call it.
        op = getattr(x, bitwise_op)
        if bitwise_op == 'bitwise_not_':
            op()          # bitwise_not_ takes no argument
        else:
            op(1)
        return x

    # Each iteration hits the slow vmap fallback and emits the warning.
    t = functorch.vmap(fn)(x)
```
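As a point of comparison (and a possible workaround until the batching rules land), the out-of-place variants appear to go through existing batching rules, so rewriting the in-place calls functionally avoids the fallback. A minimal sketch, assuming the out-of-place ops are covered:

```python
import torch
import functorch

x = torch.zeros(3, 3, dtype=torch.long)

def fn(x):
    # Out-of-place rewrites of the in-place calls above (assumption: these
    # already have batching rules, so no fallback warning is emitted).
    x = torch.bitwise_xor(x, 1)   # instead of x.bitwise_xor_(1)
    x = torch.bitwise_not(x)      # instead of x.bitwise_not_()
    return x

t = functorch.vmap(fn)(x)
```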
Hello @zou3519, @samdow,
I get the following warning when I try to compute Jacobians of my model.
```
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::bitwise_xor_.Tensor. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /opt/conda/conda-bld/pytorch_1666642991888/work/aten/src/ATen/functorch/BatchedFallback.cpp:82.)
```
The above warning is triggered by the code below.
```python
def spatial_hash(coords: List[Tensor]) -> Tensor:
    PRIMES = (1, 2654435761, 805459861, 3674653429)
    assert len(coords) <= len(PRIMES), "Add more PRIMES!"
```
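For illustration, a minimal self-contained sketch (the model, lookup table, and shapes below are hypothetical, not the code above) of how an in-place XOR on an integer tensor derived from the vmapped input hits the fallback while computing per-sample Jacobians:

```python
import torch
import functorch

# Hypothetical example: an integer branch derived from the batched input uses
# an in-place XOR, which currently has no batching rule.
table = torch.randn(16)

def model(x):
    coords = (x * 4).long()   # integer tensor derived from the input
    coords.bitwise_xor_(3)    # in-place op -> slow vmap fallback + warning
    return (table[coords % 16] * x).sum()

x = torch.rand(8, 5)
# The outer vmap batches x, so coords is batched too and the in-place
# bitwise_xor_ goes through the fallback, emitting the UserWarning above.
jac = functorch.vmap(functorch.jacrev(model))(x)
```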
Are there plans to implement batching rules for the above operations in the near future?
Thank you