Closed iliTheFallen closed 1 year ago
torch.all might be a bit tricky. To clarify the semantics: what do you expect the following to return?
x = torch.tensor([[True, False], [True, True]])
result = vmap(torch.all)(x)
Is result tensor([False, True]), or is it tensor(True)?
Neither of them. I want it to return "torch.Tensor(False)". I don't need to specify a dimension when I call "torch.all".
Neither of them. I want it to return "torch.Tensor(False)".
Got it. Sorry, I miswrote when I wrote tensor(True). Do you have some more context behind the use case and why you would want it to return tensor(False)?
For some more background, here is the problem: the semantics of vmap(f)(xs) are that, in the absence of side effects, it should be equivalent to torch.stack([f(x) for x in xs]).
Using that mental model on the example,
x = torch.tensor([[True, False], [True, True]])
result = vmap(torch.all)(x)
we get that the result should be tensor([False, True]).
If we want this to return tensor(False), we should add some new special operations that are allowed to reduce along the dimension being vmapped.
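To make the two candidate results concrete, here is a minimal sketch of the stack-based mental model described above. The helper name stack_reference is hypothetical and only models the semantics with a plain Python loop; it is not how vmap is implemented.

```python
import torch


def stack_reference(f, xs):
    """Hypothetical reference model of vmap(f)(xs): apply f to each slice
    along dim 0, then stack the per-slice results."""
    return torch.stack([f(x) for x in xs])


x = torch.tensor([[True, False], [True, True]])

# One result per batch item: the batch dimension survives.
per_item = stack_reference(torch.all, x)

# Plain torch.all with no dim argument reduces over every dimension,
# including the one that vmap would treat as the batch dimension.
whole = torch.all(x)

print(per_item.tolist())  # [False, True]
print(whole.item())       # False
```

Under this model, returning tensor(False) from the vmapped call would require f to see across batch items, which the per-slice loop does not allow.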
My use case is as simple as follows:
Suppose that you have a tensor "inpData" of size (B, N+1, P), where B=batchSize, N=# of neighbouring locations, and P=# of features each location has.
Suppose that the slices corresponding to a set of batch items are completely filled with "torch.nan", i.e.:
S = {B(i) | torch.all(torch.isnan(B(i))) = True},
where B(i) is the i-th batch item. In code, what I would like to do is:
# inpData of size (B, N+1, P)
nanIdx = torch.isnan(inpData)
if nanIdx.all():
    return None
# Find empty slices only along specified dimension
f = functorch.vmap(torch.all, in_dims=dim)
S = f(nanIdx) # (B)
return S.logical_not()
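For comparison, the same per-item check can be sketched without vmap by flattening every non-batch dimension and reducing with all(dim=1). This is only an illustrative alternative under the assumption that the goal is one boolean per batch item; the function name non_empty_batch_mask and the batch_dim parameter are made up for this example.

```python
import torch


def non_empty_batch_mask(inp_data, batch_dim=0):
    """Return a bool tensor of shape (B,) that is True for batch items
    containing at least one non-NaN value, or None if everything is NaN."""
    nan_mask = torch.isnan(inp_data)
    if nan_mask.all():
        return None
    # Move the batch dimension to the front and collapse the rest,
    # so each row holds all entries of one batch item.
    flat = nan_mask.movedim(batch_dim, 0).flatten(start_dim=1)
    # True where a batch item consists entirely of NaN values.
    empty = flat.all(dim=1)
    return empty.logical_not()


# Usage: 3 batch items of shape (2, 2); the middle one is all NaN.
data = torch.zeros(3, 2, 2)
data[1] = float("nan")
mask = non_empty_batch_mask(data)
print(mask.tolist())  # [True, False, True]
```

Because it uses only a dim-wise reduction, this variant does not go through the vmap fallback path at all.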
Hi, I think it makes more sense to keep the semantics consistent with torch.stack([f(x) for x in xs]), as each item along the batch dim is completely independent of the others. Will this be supported at some point (even with some warnings)? Thanks!
Hi, I think it makes more sense to keep the semantics consistent with torch.stack([f(x) for x in xs]), as each item along the batch dim is completely independent of the others.
It's not too difficult to add with those semantics; we can prioritize it.
When I vmap the "torch.all" function, I get the following warning:
/tmp/ipykernel_39088/2496106444.py:7: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::all. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:85.)
  f = functorch.vmap(torch.all, in_dims=1)
Is it possible to make it available for v1.12?