pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Rematerialization on arbitrary graphs, mincut #925

Closed: yushangdi closed this 1 year ago

yushangdi commented 2 years ago

The algorithm is in functorch/_src/remat_utils_mincut.py.

You can run the benchmarks with:

python benchmarks/benchmark_remat_fullgraphs.py --isolate --devices='cuda' -k 'timm_nfnet'

You can pass additional -k flags to add more models to the run. If no -k models are specified, the script runs all models (some will fail because the model contains more than one graph).

If you only want to see how much memory the mincut optimization saves, WITHOUT actually benchmarking performance, use the --info flag.
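When scripting several runs, it can help to assemble the command programmatically, one -k flag per model. The sketch below is a hypothetical helper (build_remat_benchmark_cmd is not part of the PR); it uses only the script path and flags mentioned above, and it assumes that --info can be combined with --isolate and --devices, which the PR text does not state explicitly. resnet50 in the usage line is just an illustrative torchbench model name.

```python
import shlex
from typing import List, Optional

# Script path and flags as given in the PR description.
SCRIPT = "benchmarks/benchmark_remat_fullgraphs.py"


def build_remat_benchmark_cmd(models: Optional[List[str]] = None,
                              device: str = "cuda",
                              info_only: bool = False) -> List[str]:
    """Hypothetical helper: assemble the argv for benchmark_remat_fullgraphs.py.

    One -k flag is emitted per model. With no models, no -k flag is passed,
    which (per the PR text) makes the script run every model.
    """
    # --devices=cuda is what the shell passes for --devices='cuda'.
    cmd = ["python", SCRIPT, "--isolate", f"--devices={device}"]
    for name in models or []:
        cmd += ["-k", name]
    if info_only:
        # Assumption: --info can be combined with the flags above; it only
        # reports the memory reduction from the mincut optimization.
        cmd.append("--info")
    return cmd


if __name__ == "__main__":
    # Print a copy-pasteable shell command for two models.
    print(shlex.join(build_remat_benchmark_cmd(["timm_nfnet", "resnet50"])))
```

The resulting list can be passed directly to subprocess.run(cmd, check=True), which avoids shell-quoting issues with the --devices value.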

yushangdi commented 1 year ago

The algorithm is in functorch/_src/remat_utils_mincut.py.

benchmarks/benchmark_remat_examples.py: benchmarks on some hand-crafted small examples
benchmarks/benchmark_remat_fullgraphs.py: benchmarks on full graphs of torchbench models
benchmarks/benchmark_remat_torchbench.py: benchmarks on dumped forward and backward torchbench models
benchmarks/benchmark_remat_utils.py: utilities for benchmarking
benchmarks/torchbench_utils.py: utilities for loading torchbench models

test/test_remat_mincut.py: unit tests for the algorithm
test/test_remat_torchbench.py: tests for the algorithm on torchbench forward and backward graphs

yushangdi commented 1 year ago

Migrated to https://github.com/pytorch/pytorch/pull/82143