pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Small (<5%) regression in functorch_maml_omniglot_cuda model from 0.2.1 to latest #1040

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

On my machine with P100 GPUs, the runtime goes from 286ms to 316ms.

To repro:

```shell
# setup pytorch/benchmark
git clone https://github.com/pytorch/benchmark
cd benchmark
# this doesn't need to complete successfully -- we just need to install torchbenchmark's basic dependencies.
python setup.py install

python run_benchmark.py functorch
```
samdow commented 1 year ago

On A100s, I'm seeing 107ms to 111ms, a ~4% regression.

samdow commented 1 year ago

On AWS V100s, I'm seeing 169ms to 174ms, which is a ~3% regression.

zou3519 commented 1 year ago

With V100s on the FAIR cluster, I see 190ms to 211ms, which is roughly 5%, so I can reproduce your V100 numbers, @samdow. The regression isn't large enough to investigate for the 1.13 release, but it might be worth checking whether it comes from overhead or from something else.
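The percentages in this thread are relative slowdowns, (new - old) / old. A minimal Python sketch that applies that formula to the timings reported above (the setup labels are just descriptive strings, not benchmark names):

```python
# Relative slowdown = (new - old) / old * 100, using the timings reported in this thread.
reported = {
    "P100 (zou3519)": (286, 316),
    "A100 (samdow)": (107, 111),
    "AWS V100 (samdow)": (169, 174),
    "FAIR V100 (zou3519)": (190, 211),
}


def regression_pct(old_ms: float, new_ms: float) -> float:
    """Percent slowdown of new_ms relative to old_ms (positive means slower)."""
    return (new_ms - old_ms) / old_ms * 100.0


for setup, (old, new) in reported.items():
    print(f"{setup}: {old}ms -> {new}ms = {regression_pct(old, new):+.1f}%")
```

Dividing by the old (0.2.1) timing keeps the numbers comparable across GPUs whose absolute runtimes differ by almost 3x.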
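One possible way to tell fixed overhead apart from other causes, offered as an assumption rather than anything the benchmark does: time each version at several input sizes and fit `time_ms = overhead_ms + per_item_ms * n` by least squares. The helper name `fit_overhead` and all timings below are hypothetical and only illustrate the idea. You would have to vary the workload size yourself, since this is not a torchbenchmark option.

```python
# Hypothetical diagnostic: split runtime into a fixed per-call part (intercept)
# and a part that scales with input size n (slope) via ordinary least squares.
def fit_overhead(sizes, times_ms):
    """Return (overhead_ms, per_item_ms) for the line time_ms = overhead_ms + per_item_ms * n."""
    if len(sizes) != len(times_ms) or len(set(sizes)) < 2:
        raise ValueError("need matching lists with at least two distinct sizes")
    count = len(sizes)
    mean_x = sum(sizes) / count
    mean_y = sum(times_ms) / count
    sxx = sum((x - mean_x) ** 2 for x in sizes)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(sizes, times_ms))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return intercept, slope


# Made-up timings for illustration only (not measurements from this issue).
sizes = [8, 16, 32, 64]
old_times = [36, 52, 84, 148]  # hypothetical "0.2.1" timings
new_times = [46, 62, 94, 158]  # hypothetical "latest" timings

for label, times in (("old", old_times), ("new", new_times)):
    overhead, per_item = fit_overhead(sizes, times)
    print(f"{label}: overhead={overhead:.1f}ms, per_item={per_item:.2f}ms")
```

If the two versions differ mainly in the intercept, that points toward per-call overhead (e.g. Python or dispatch cost). A change in the slope points toward the per-element work itself. Real timings are noisy, so each point should be a median over repeated runs.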