pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Remove removable randomness skips #953

Closed: zou3519 closed this 1 year ago

zou3519 commented 1 year ago

The jvp and vjp transforms should not change randomness behavior. For example, dropout under vjp and dropout under regular PyTorch autograd should produce the same values. vmap, however, does change randomness behavior.
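To illustrate how vmap differs, here is a minimal sketch. It assumes PyTorch 2.x, where functorch's transforms are exposed as `torch.func` and `vmap` accepts a `randomness` argument (`"error"` by default, or `"same"` / `"different"`). The function `f` is a hypothetical example and is not code from this PR. Under the default mode, random ops raise an error. With `"same"`, every batch element gets the same dropout mask. With `"different"`, each batch element draws its own mask.

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

def f(x):
    # Hypothetical example: dropout in training mode draws a random mask.
    return F.dropout(x, p=0.5, training=True)

xs = torch.ones(4, 1000)  # a batch of 4 identical rows

# Default randomness="error": random operations under vmap raise.
try:
    vmap(f)(xs)
    raised = False
except RuntimeError:
    raised = True

# randomness="same": one mask is shared across the batch dimension.
same = vmap(f, randomness="same")(xs)
# randomness="different": each batch element gets its own mask.
different = vmap(f, randomness="different")(xs)

rows_identical_same = all(torch.equal(same[0], same[i]) for i in range(4))
rows_identical_diff = all(torch.equal(different[0], different[i]) for i in range(4))
```

Because the input rows are identical, the rows of `same` match. The rows of `different` differ, except with negligible probability for 1000 elements.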

This PR removes a number of randomness skips from the jvp-only and vjp-only tests. It also fixes our implementation of dropout so that it maintains the above property.
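The property the dropout fix is meant to preserve can be spot-checked roughly as follows. This is a sketch that assumes PyTorch 2.x, where `torch.func.vjp` is the successor of `functorch.vjp`. It is not one of the tests touched by this PR, and `f` is a hypothetical example. When both runs use the same seed, dropout under `vjp` should produce the same output and the same gradient as dropout under regular autograd.

```python
import torch
import torch.nn.functional as F
from torch.func import vjp

def f(x):
    # Hypothetical example: dropout in training mode uses the global RNG.
    return F.dropout(x, p=0.5, training=True)

x = torch.randn(1000)

# Reference run with regular PyTorch autograd.
torch.manual_seed(0)
x_ref = x.clone().requires_grad_()
out_ref = f(x_ref)
out_ref.backward(torch.ones_like(out_ref))
grad_ref = x_ref.grad

# Same computation under the vjp transform, reseeded identically.
torch.manual_seed(0)
out_vjp, vjp_fn = vjp(f, x)
(grad_vjp,) = vjp_fn(torch.ones_like(out_vjp))  # one cotangent per primal

outputs_match = torch.equal(out_ref.detach(), out_vjp)
grads_match = torch.equal(grad_ref, grad_vjp)
```

Reseeding before each run ensures that both runs draw the dropout mask from the same RNG state. Any mismatch would therefore mean the transform consumed randomness differently.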

Test Plan: