pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

UserWarning: There is a performance drop because we have not yet implemented the batching rule for `aten::greater_equal.Scalar` #1080

Closed: carbocation closed this issue 1 year ago

carbocation commented 1 year ago

🚀 The feature, motivation and pitch

In PyTorch 1.14.0.dev20221128, using vmap with torch.greater_equal produces the following warning:

UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::greater_equal.Scalar. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
  idx = torch.argmax(torch.greater_equal(torch.cumsum(input, dim=0), quantile).to(torch.int))
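A minimal reproduction might look like the sketch below. Only the body expression comes from the warning above. The function name `quantile_index`, the sample weights and the 0.5 threshold are hypothetical. Because `quantile` is passed as a plain Python float, `torch.greater_equal` should bind to its `Scalar` overload, which is the one named in the warning. On builds without the batching rule, vmap falls back to a per-example loop, so the result still matches the loop below. Only the speed differs.

```python
import warnings

import torch

try:
    from torch.func import vmap  # PyTorch 2.x location
except ImportError:
    from functorch import vmap  # older nightlies / standalone functorch


def quantile_index(weights, quantile):
    # Hypothetical helper: index of the first position where the running
    # total of `weights` reaches `quantile` (argmax returns the first max).
    reached = torch.greater_equal(torch.cumsum(weights, dim=0), quantile)
    return torch.argmax(reached.to(torch.int))


weights = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                        [0.4, 0.3, 0.2, 0.1]])

# Batch over rows of `weights`; `quantile` is a non-tensor, so its in_dim is None.
batched = vmap(quantile_index, in_dims=(0, None))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    idx = batched(weights, 0.5)

# Reference result computed one row at a time, without vmap.
looped = torch.stack([quantile_index(w, 0.5) for w in weights])

fallback_hit = any("batching rule" in str(w.message) for w in caught)
print(idx.tolist(), looped.tolist(), "fallback warning:", fallback_hit)
```

If the build predates the fix, `fallback_hit` is expected to be `True`. Either way, the batched and looped indices should agree.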

Alternatives

The alternative is to accept the performance drop; the computation still completes correctly.

Additional context

No response

cc @zou3519 @Chillee @samdow @soumith

jiayisunx commented 1 year ago

@carbocation, could you please provide an example to reproduce this issue?

zou3519 commented 1 year ago

Thanks for the issue, we'll prioritize it.

kshitij12345 commented 1 year ago

Fixed in https://github.com/pytorch/pytorch/pull/91324