pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Performance drop because of not yet implemented batching rule. #1069

Closed elientumba2019 closed 1 year ago

elientumba2019 commented 1 year ago

Hello @zou3519, @samdow,

I get the following warning when I try to compute Jacobians of my model:

```
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::bitwise_xor.Tensor. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /opt/conda/conda-bld/pytorch_1666642991888/work/aten/src/ATen/functorch/BatchedFallback.cpp:82.)
```

The above warning is triggered by the code below:

```python
from typing import List

import torch
from torch import Tensor


def spatial_hash(coords: List[Tensor]) -> Tensor:
    # Large primes used to mix the coordinate components together.
    PRIMES = (1, 2654435761, 805459861, 3674653429)

    assert len(coords) <= len(PRIMES), "Add more PRIMES!"

    if len(coords) == 1:
        i = coords[0] ^ PRIMES[1]
    else:
        i = coords[0] ^ PRIMES[0]
        for c, p in zip(coords[1:], torch.tensor(PRIMES[1:]).cuda()):
            i ^= c * p  # bitwise XOR with a tensor operand
    return i
```
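
For context, a minimal sketch of the kind of call that surfaces the warning (shapes and inputs are hypothetical; CUDA is assumed because `spatial_hash` moves its primes to the GPU):

```python
# Hypothetical repro (assumed shapes/inputs): vmap-ing spatial_hash over
# a batch of integer coordinates emits the warning above, because the
# XOR ops dispatch to aten::bitwise_xor variants with no batching rule.
import torch
from functorch import vmap

coords_a = torch.randint(0, 1 << 16, (8, 16), device="cuda")
coords_b = torch.randint(0, 1 << 16, (8, 16), device="cuda")

hashed = vmap(lambda a, b: spatial_hash([a, b]))(coords_a, coords_b)
```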

Are there plans to implement batching rules for the above operations in the near future?

Thank you

zou3519 commented 1 year ago

This shouldn't be difficult to do; we can prioritize it.
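
In the meantime, the warning indicates a performance drop rather than incorrect results: the fallback evaluates the op per sample. A quick sketch (small CPU tensors, for illustration) checking vmap's output against a manual per-sample loop:

```python
# Sketch: the vmap fallback is slow but still computes the right answer.
import torch
from functorch import vmap

x = torch.randint(0, 1 << 16, (4, 8))
y = torch.randint(0, 1 << 16, (4, 8))

out_vmap = vmap(torch.bitwise_xor)(x, y)  # warns, falls back per sample
out_loop = torch.stack([a ^ b for a, b in zip(x, y)])
assert torch.equal(out_vmap, out_loop)
```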

kshitij12345 commented 1 year ago

This is also true for other bitwise ops. What concerns me is why `test_op_has_batch_rule` does not catch it 🤔

```python
import torch
import functorch

# Each of these in-place bitwise ops triggers the decomposition fallback
# (and its performance warning) under vmap.
for bitwise_op in ['bitwise_xor_', 'bitwise_or_', 'bitwise_and_', 'bitwise_not_',
                   'bitwise_left_shift_', 'bitwise_right_shift_']:

    x = torch.zeros(3, 3, dtype=torch.long)

    def fn(x):
        op = getattr(x, bitwise_op)
        if bitwise_op == 'bitwise_not_':
            op()  # unary: takes no second operand
        else:
            op(1)
        return x

    t = functorch.vmap(fn)(x)
```
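
One way to survey which variants actually fall back (a sketch; the op list mirrors the binary in-place set above, using the out-of-place forms) is to record whether each op warns under vmap:

```python
import warnings

import torch
import functorch

x = torch.zeros(3, 3, dtype=torch.long)
binary_ops = (torch.bitwise_xor, torch.bitwise_or, torch.bitwise_and,
              torch.bitwise_left_shift, torch.bitwise_right_shift)

for op in binary_ops:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        functorch.vmap(lambda t, f=op: f(t, torch.ones_like(t)))(x)
    # Any recorded warning indicates the op took the slow fallback path.
    print(f"{op.__name__}: {'fallback' if caught else 'batched'}")
```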