The jvp and vjp transforms should not change randomness behavior. For
example, given the same RNG seed, dropout under vjp and under regular
PyTorch autograd should produce the same values. vmap, however, does
change randomness behavior.
This PR removes a number of randomness skips from the jvp-only and
vjp-only tests and fixes our implementation of dropout so that it
maintains the above property.
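
As an illustration of the property described above, here is a minimal sketch, not taken from this PR's tests. It assumes a recent PyTorch where the transform is exposed as `torch.func.vjp` (older setups would import it from `functorch`), runs on CPU, and uses an arbitrary seed, `p=0.5`, and an all-ones input chosen only for the example.

```python
import torch
import torch.nn.functional as F
from torch.func import vjp  # assumption: PyTorch >= 2.0; older code used functorch.vjp

# Illustrative input: all ones, so the dropout output equals the scaled mask.
x = torch.ones(8)

# Path 1: regular PyTorch autograd.
torch.manual_seed(0)
x_ag = x.clone().requires_grad_(True)
out_autograd = F.dropout(x_ag, p=0.5, training=True)
(grad_autograd,) = torch.autograd.grad(out_autograd.sum(), x_ag)

# Path 2: the vjp transform, starting from the same RNG seed.
torch.manual_seed(0)
out_vjp, vjp_fn = vjp(lambda t: F.dropout(t, p=0.5, training=True), x)
(grad_vjp,) = vjp_fn(torch.ones_like(out_vjp))

# If vjp leaves randomness behavior unchanged, both comparisons should agree.
print("forward outputs match:", torch.equal(out_autograd.detach(), out_vjp))
print("gradients match:", torch.equal(grad_autograd, grad_vjp))
```

Reseeding before each path is what makes the comparison meaningful: both calls then start from the same generator state, so any difference would come from the transform itself.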
Test Plan: