pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

masked_fill.Scalar batch rule #964

Closed by zou3519 1 year ago

zou3519 commented 1 year ago

Related to https://github.com/pytorch/functorch/issues/957

Writing a batching rule for masked_fill.Tensor is difficult, so I didn't write one for that overload.
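For context, a batching rule tells vmap how to run an operator when some inputs carry an extra batch dimension. The sketch below approximates a masked_fill.Scalar rule in Python. It assumes the usual functorch convention that a rule receives each tensor together with its batch-dim index (or None) and returns the output plus the output's batch-dim index. The helper names (`masked_fill_scalar_batch_rule`, `_to_front`, `_pad_logical`) are hypothetical, and this is not the code from this issue; functorch's real rules are written in C++. With the Scalar overload the fill value is one Python number shared by every example, so a single masked_fill call on the aligned, stacked inputs is enough.

```python
from typing import Optional, Tuple

import torch


def _to_front(x: torch.Tensor, bdim: Optional[int]) -> torch.Tensor:
    # Move the vmapped dim to position 0; an unbatched tensor gets a size-1
    # leading dim so it broadcasts across the batch.
    return x.unsqueeze(0) if bdim is None else x.movedim(bdim, 0)


def _pad_logical(x: torch.Tensor, logical_rank: int) -> torch.Tensor:
    # Insert size-1 dims right after the batch dim so the per-example
    # (logical) dims stay right-aligned for broadcasting.
    while x.dim() - 1 < logical_rank:
        x = x.unsqueeze(1)
    return x


def masked_fill_scalar_batch_rule(
    self: torch.Tensor,
    self_bdim: Optional[int],
    mask: torch.Tensor,
    mask_bdim: Optional[int],
    value: float,
) -> Tuple[torch.Tensor, int]:
    # Hypothetical rule; assumes at least one of self/mask is batched.
    self_rank = self.dim() - (1 if self_bdim is not None else 0)
    mask_rank = mask.dim() - (1 if mask_bdim is not None else 0)
    rank = max(self_rank, mask_rank)
    self_ = _pad_logical(_to_front(self, self_bdim), rank)
    mask_ = _pad_logical(_to_front(mask, mask_bdim), rank)
    # Broadcast explicitly so the output has the full batched shape even
    # when only the mask is batched.
    self_, mask_ = torch.broadcast_tensors(self_, mask_)
    # The scalar fill value is the same for every example, so one call suffices.
    return self_.masked_fill(mask_, value), 0
```

Moving the batch dimension to the front and padding the logical dims keeps broadcasting identical to the per-example call, so the result matches running `x[:, b].masked_fill(m[b], value)` in a loop and stacking the outputs.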