pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.38k stars, 102 forks

(fix CI) Add batch rule for split.sizes #952

Closed: samdow closed this 1 year ago

samdow commented 1 year ago

All the errors I'm seeing on CI happen because split.sizes (the overload of split that takes a list of sizes) doesn't have a batch rule. This PR uses decompositions to register a batching rule for split.sizes.
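To illustrate the general idea behind batching via decomposition, here is a minimal plain-Python sketch. It is not functorch's actual C++ registration. Assumptions: rectangular nested lists stand in for tensors, the batch dimension is always the leading one (bdim 0), negative dims are not handled, and every name (size, narrow, split_sizes_decomp, split_sizes_batch_rule) is hypothetical. The point it shows: once split.sizes is written in terms of primitives that already have batch rules, a batch rule for split.sizes follows by running the same decomposition with the batched primitives.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical stand-in for a tensor: a rectangular nested Python list.
Tensor = Any


def size(x: Tensor, dim: int) -> int:
    """Primitive op: length of `x` along `dim`."""
    return len(x) if dim == 0 else size(x[0], dim - 1)


def narrow(x: Tensor, dim: int, start: int, length: int) -> Tensor:
    """Primitive op: take `length` entries starting at `start` along `dim`."""
    if dim == 0:
        return x[start:start + length]
    # Recurse into each sub-list until the target dimension is reached.
    return [narrow(row, dim - 1, start, length) for row in x]


def split_sizes_decomp(x: Tensor, sizes: List[int], dim: int,
                       narrow_fn: Callable, size_fn: Callable) -> List[Tensor]:
    """Decomposition: split.sizes written only in terms of size and narrow."""
    if sum(sizes) != size_fn(x, dim):
        raise ValueError("split sizes must sum to the length of dim")
    pieces, start = [], 0
    for length in sizes:
        pieces.append(narrow_fn(x, dim, start, length))
        start += length
    return pieces


def split_sizes(x: Tensor, sizes: List[int], dim: int = 0) -> List[Tensor]:
    """Unbatched split.sizes: run the decomposition with the plain primitives."""
    return split_sizes_decomp(x, sizes, dim, narrow, size)


def split_sizes_batch_rule(x: Tensor, bdim: int, sizes: List[int],
                           dim: int = 0) -> Tuple[List[Tensor], int]:
    """Batch rule derived from the decomposition (only handles bdim == 0)."""
    if bdim != 0:
        raise NotImplementedError("this sketch only handles a leading batch dim")

    # Batch rules for the primitives: with the batch dim in front, logical
    # dim d is physical dim d + 1, and outputs keep the batch dim at 0.
    def narrow_batched(t: Tensor, d: int, start: int, length: int) -> Tensor:
        return narrow(t, d + 1, start, length)

    def size_batched(t: Tensor, d: int) -> int:
        return size(t, d + 1)

    pieces = split_sizes_decomp(x, sizes, dim, narrow_batched, size_batched)
    return pieces, 0


# Usage: a batch of two 2x3 examples, split along logical dim 1 into sizes [1, 2].
batch = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
pieces, out_bdim = split_sizes_batch_rule(batch, 0, [1, 2], dim=1)
print(pieces[0])
```

The decomposition is written once and parameterized by its primitives, so the batched version needs no split-specific logic; it only reuses the batch rules of size and narrow. Each batched piece should match stacking the per-example results of split_sizes.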