pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap over conv_transpose on CUDA appears incorrect on A100 GPUs #1050

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

Running the conv_transpose vmap tests on an A100 produces failures like:

E           AssertionError: Tensor-likes are not close!
E
E           Mismatched elements: 1670 / 2240 (74.6%)
E           Greatest absolute difference: 0.1141815185546875 at index (0, 0, 0, 1, 5, 4) (up to 0.0001 allowed)
E           Greatest relative difference: 0.1814146823933447 at index (0, 0, 1, 5, 2, 4) (up to 0.0001 allowed)

Repro:

pytest test_vmap.py -v -k "conv_transpose"
zou3519 commented 1 year ago

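This is expected. On Ampere GPUs such as the A100, these convolutions can run in TF32, which keeps fp32's 8-bit exponent but only 10 mantissa bits, so it doesn't have the full precision of fp32 (https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere). The resulting errors easily exceed the 1e-4 tolerance the test uses.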
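To illustrate why a 1e-4 tolerance is too tight for TF32, here is a small stdlib-only sketch that emulates TF32 by keeping only the top 10 of fp32's 23 mantissa bits. It is a simplification: it truncates the dropped bits, while real hardware may round instead. The helper names `to_fp32`, `to_tf32` and `rel_err` are hypothetical.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Emulate TF32: keep fp32's sign, 8-bit exponent and top 10 mantissa bits.

    Simplifying assumption: the 13 low-order mantissa bits (23 - 10) are
    truncated to zero; actual hardware may round rather than truncate.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # zero out the 13 lowest mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rel_err(approx: float, exact: float) -> float:
    """Relative error of approx with respect to a nonzero exact value."""
    return abs(approx - exact) / abs(exact)


if __name__ == "__main__":
    x = 0.1
    # fp32 keeps 0.1 well within a 1e-4 relative tolerance...
    print("fp32:", to_fp32(x), rel_err(to_fp32(x), x))
    # ...but a single TF32 rounding of the same input already exceeds it.
    print("tf32:", to_tf32(x), rel_err(to_tf32(x), x))
```

A single TF32 rounding of 0.1 gives 0.0999755859375, a relative error of about 2.4e-4, before any accumulation across a convolution. In PyTorch, TF32 for cuDNN convolutions is controlled by `torch.backends.cudnn.allow_tf32` (default `True`); setting it to `False` is one way to get fp32-accurate reference results on Ampere.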