pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

aten::all not implemented #1060

Closed. iliTheFallen closed this issue 1 year ago.

iliTheFallen commented 1 year ago

When I vmap the torch.all function, I get the following:

/tmp/ipykernel_39088/2496106444.py:7: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::all. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:85.)
  f = functorch.vmap(torch.all, in_dims=1)
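
For reference, a minimal standalone repro (a sketch; any boolean input shows the same warning):

import torch
import functorch

x = torch.rand(3, 5) > 0.5                 # an arbitrary boolean tensor of shape (3, 5)
f = functorch.vmap(torch.all, in_dims=1)   # map torch.all over dim 1
print(f(x))                                # shape (5,); emits the fallback warning above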

Is it possible to make it available for v1.12?

zou3519 commented 1 year ago

torch.all might be a bit tricky. To clarify the semantics: what do you expect the following to return?

x = torch.tensor([[True, False], [True, True]])
result = vmap(torch.all)(x)

Is result tensor([False, True]), or is it tensor(True)?

iliTheFallen commented 1 year ago

Neither of them. I want it to return "torch.Tensor(False)". I don't need to specify a dimension when I call "torch.all".

zou3519 commented 1 year ago

> Neither of them. I want it to return "torch.Tensor(False)".

Got it. Sorry, I miswrote when I wrote tensor(True). Do you have some more context behind the use case and why you would want it to return tensor(False)?

For some more background, here is the problem: the semantics of vmap(f)(xs) are that, in the absence of side effects, it should be equivalent to torch.stack([f(x) for x in xs]).

Using that mental model on the example,

x = torch.tensor([[True, False], [True, True]])
result = vmap(torch.all)(x)

we get that the result should be tensor([False, True]).

If we want this to return tensor(False), we should add some new special operations that are allowed to reduce along the dimension being vmapped.
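
A short sketch of that mental model in runnable form (assuming functorch is importable; in newer PyTorch the same vmap lives under torch.func):

import torch
import functorch

x = torch.tensor([[True, False], [True, True]])

# Reference semantics: apply the function to each slice along dim 0, then stack.
reference = torch.stack([torch.all(row) for row in x])  # tensor([False, True])

# vmap should match the reference (it currently does so via the slow fallback).
assert torch.equal(functorch.vmap(torch.all)(x), reference)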

iliTheFallen commented 1 year ago

My use case is as simple as follows:

Suppose you have a tensor "inpData" of size (B, N+1, P), where B = batch size, N = the number of neighbouring locations, and P = the number of features each location has.

Suppose further that the slices corresponding to some subset of batch items are filled entirely with torch.nan, i.e.:

S = { B(i) | torch.all(torch.isnan(B(i))) = True },

where B(i) is the i-th batch item. In code, what I would like to do is:

# inpData of size (B, N+1, P)
nanIdx = torch.isnan(inpData)
if nanIdx.all():  # every entry in the batch is NaN
    return None
# Find all-NaN slices along the batch dimension (here dim = 0)
f = functorch.vmap(torch.all, in_dims=dim)
S = f(nanIdx)  # (B,)
return S.logical_not()
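
For what it's worth, the same per-batch reduction can be written without vmap by flattening the non-batch dimensions first (a sketch, assuming the batch dimension is 0):

nanIdx = torch.isnan(inpData)               # (B, N+1, P)
S = nanIdx.flatten(start_dim=1).all(dim=1)  # (B,): True where a batch item is all NaN
return S.logical_not()
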
kxhit commented 1 year ago

>> Neither of them. I want it to return "torch.Tensor(False)".
>
> Got it. Sorry, I miswrote when I wrote tensor(True). Do you have some more context behind the use case and why you would want it to return tensor(False)?
>
> For some more background, here is the problem: the semantics of vmap(f)(xs) are that, in the absence of side effects, it should be equivalent to torch.stack([f(x) for x in xs]).
>
> Using that mental model on the example,
>
> x = torch.tensor([[True, False], [True, True]])
> result = vmap(torch.all)(x)
>
> we get that the result should be tensor([False, True]).
>
> If we want this to return tensor(False), we should add some new special operations that are allowed to reduce along the dimension being vmapped.

Hi, I think it makes more sense to keep the semantics following torch.stack([f(x) for x in xs]), since each slice along the batch dim is completely independent of the others. Will this be supported at some point (even with some warnings)? Thanks!

zou3519 commented 1 year ago

> Hi, I think it makes more sense to keep the semantics following torch.stack([f(x) for x in xs]), since each slice along the batch dim is completely independent of the others.

It's not too difficult to add a batching rule that follows those semantics; we can prioritize it.

kshitij12345 commented 1 year ago

Fixed in https://github.com/pytorch/pytorch/pull/91966
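
With that batching rule merged, vmap(torch.all) follows the per-example semantics discussed above. A quick check (assuming a recent PyTorch, where vmap lives under torch.func):

import torch
from torch.func import vmap  # functorch.vmap in older releases

x = torch.tensor([[True, False], [True, True]])
print(vmap(torch.all)(x))  # tensor([False, True]), with no fallback warning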