pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Audit CompositeImplicitAutograd ops that do not have a batching rule, add them to BatchRulesDecomposition #1087

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

This is low-hanging fruit: a number of CompositeImplicitAutograd ops do not have a batching rule. We should just add all of them to BatchRulesDecomposition. They should be easy to detect with a test similar to the one @srossross wrote in https://github.com/pytorch/pytorch/pull/89465/files
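At its core, the audit is a set difference. The sketch below assumes the three collections of op names have already been gathered: CompositeImplicitAutograd ops, ops with a batching rule, and ops already listed in BatchRulesDecomposition. Collecting them from the PyTorch dispatcher is not shown. The function name `find_unregistered_composite_ops`, the `allowlist` parameter for deliberate exceptions, and the `aten::op_*` names are hypothetical. They do not come from functorch or from the linked PR.

```python
from typing import Iterable, List


def find_unregistered_composite_ops(
    composite_implicit_ops: Iterable[str],
    batch_rule_ops: Iterable[str],
    decomposition_ops: Iterable[str],
    allowlist: Iterable[str] = (),
) -> List[str]:
    """Return (sorted) CompositeImplicitAutograd ops that have neither a
    batching rule nor a BatchRulesDecomposition entry, skipping any op in
    the hypothetical allowlist of intentional exceptions."""
    covered = set(batch_rule_ops) | set(decomposition_ops) | set(allowlist)
    return sorted(set(composite_implicit_ops) - covered)


if __name__ == "__main__":
    # Hypothetical op names, for illustration only. In a real audit these
    # sets would be collected from the dispatcher registrations.
    composite = {"aten::op_a", "aten::op_b", "aten::op_c", "aten::op_d"}
    batched = {"aten::op_a"}      # has a dedicated batching rule
    decomposed = {"aten::op_b"}   # already in BatchRulesDecomposition
    skip = {"aten::op_d"}         # deliberately excluded
    print(find_unregistered_composite_ops(composite, batched, decomposed, skip))
    # prints ['aten::op_c']
```

Each name the function returns is a candidate to add to BatchRulesDecomposition. The allowlist gives ops that need special care a place to be recorded explicitly, so they are not silently dropped from the report.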

Here are a couple of things to be careful of:

kshitij12345 commented 1 year ago

Fixed by https://github.com/pytorch/pytorch/pull/91367