pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Speed up functorch tests in CI #1028

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

Here's roughly how the vmap testing works (pseudocode):

all_permutations = create_all_batched_inputs()
for batched_inputs in all_permutations:
  result = vmap(op)(batched_inputs)
  expected = for_loop_over_op(op, batched_inputs)
  assert torch.allclose(result, expected)
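A minimal runnable sketch of the current scheme might look like the following. It assumes PyTorch 2.0+ (`torch.func.vmap`; older setups would use `functorch.vmap`). The bodies of `create_all_batched_inputs`, `for_loop_over_op`, `check_vmap` and the constant `BATCH_SIZE` are hypothetical illustrations, not functorch's actual test utilities. It also passes `in_dims` explicitly, which the pseudocode leaves implicit.

```python
import itertools

import torch
from torch.func import vmap  # assumes PyTorch >= 2.0 (functorch.vmap on older setups)

BATCH_SIZE = 3  # hypothetical batch size used for every batched input


def create_all_batched_inputs(inputs, batch_size=BATCH_SIZE):
    """Hypothetical helper: yield (batched_inputs, in_dims) for every way of
    batching `inputs`. Each tensor is either left unbatched (in_dim None) or
    repeated `batch_size` times along a new dim inserted at 0..ndim."""
    per_input_options = []
    for x in inputs:
        options = [(x, None)]
        for bdim in range(x.dim() + 1):
            repeats = [1] * (x.dim() + 1)
            repeats[bdim] = batch_size
            # Every slice of the result along `bdim` is a copy of x.
            options.append((x.unsqueeze(bdim).repeat(*repeats), bdim))
        per_input_options.append(options)
    for combo in itertools.product(*per_input_options):
        in_dims = tuple(d for _, d in combo)
        if all(d is None for d in in_dims):
            continue  # vmap needs at least one batched input
        yield tuple(t for t, _ in combo), in_dims


def for_loop_over_op(op, batched_inputs, in_dims, batch_size=BATCH_SIZE):
    """Reference result: call op once per batch index on the i-th slices,
    then stack along dim 0 (vmap's default out_dims)."""
    results = []
    for i in range(batch_size):
        args = [x if d is None else x.select(d, i) for x, d in zip(batched_inputs, in_dims)]
        results.append(op(*args))
    return torch.stack(results)


def check_vmap(op, inputs):
    for batched_inputs, in_dims in create_all_batched_inputs(inputs):
        result = vmap(op, in_dims=in_dims)(*batched_inputs)
        # batch_size op calls per batching, recomputed every iteration
        expected = for_loop_over_op(op, batched_inputs, in_dims)
        assert torch.allclose(result, expected)
```

For example, `check_vmap(torch.add, (torch.randn(2, 3), torch.randn(2, 3)))` checks 15 batchings (4 options per 2-D input, minus the all-unbatched case) and calls `torch.add` three times per batching just to build the reference.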

There are two things that could be optimized:

  1. for_loop_over_op(op, batched_inputs). Each slice of batched_inputs is just the original input, so instead of running a for-loop over the op with slices of batched_inputs, we can run the op once on the original input and expand the result.
  2. expected is the SAME across all iterations of the for loop above, because vmap returns its output batched along dim 0 regardless of which inputs were batched. So we only need to compute it once.
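Optimization 1 can be illustrated on a single batching. This is a sketch assuming the batched inputs are built by repeating the original tensor; `expand_expected` is a hypothetical helper name, and `torch.add` stands in for an arbitrary op.

```python
import torch


def expand_expected(op, inputs, batch_size):
    """Hypothetical helper: run op once on the original, unbatched inputs and
    expand the result along a new leading batch dim (vmap's default out_dims=0)."""
    out = op(*inputs)
    # expand() returns a broadcast view, so no data is copied.
    return out.unsqueeze(0).expand(batch_size, *out.shape)


x, y = torch.randn(2, 3), torch.randn(2, 3)
batch_size = 4

# One possible batching: x repeated along a new dim 1, y left unbatched.
x_batched = x.unsqueeze(1).repeat(1, batch_size, 1)  # shape (2, 4, 3)

# Current reference: batch_size calls to op, one per slice.
loop_expected = torch.stack(
    [torch.add(x_batched.select(1, i), y) for i in range(batch_size)]
)
# Optimization 1: a single call to op, then expand.
fast_expected = expand_expected(torch.add, (x, y), batch_size)
```

Both tensors have shape `(4, 2, 3)` and hold the same values. Since `fast_expected` depends only on `op` and the original inputs, not on the chosen batching, it can also be hoisted out of the loop, which is optimization 2.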

The end-state should look something like:

all_permutations = create_all_batched_inputs()
expected = expand(op(inputs))  # inputs: the original, unbatched inputs; computed once
for batched_inputs in all_permutations:
  result = vmap(op)(batched_inputs)
  assert torch.allclose(result, expected)
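A runnable version of the end-state loop could look like this. It assumes PyTorch 2.0+ for `torch.func.vmap`, uses `torch.mul` as an example op, and replaces `create_all_batched_inputs` with a hand-picked list of three batchings; `repeat_along` and `B` are hypothetical names.

```python
import torch
from torch.func import vmap  # assumes PyTorch >= 2.0

B = 3  # hypothetical batch size
x, y = torch.randn(2, 3), torch.randn(2, 3)
op = torch.mul


def repeat_along(t, dim, batch_size=B):
    """Stack batch_size copies of t along a new dim at position `dim`."""
    return torch.stack([t] * batch_size, dim=dim)


# A hand-picked subset of batchings; a real test would enumerate all of them.
all_permutations = [
    ((repeat_along(x, 0), y), (0, None)),
    ((x, repeat_along(y, 1)), (None, 1)),
    ((repeat_along(x, 2), repeat_along(y, 0)), (2, 0)),
]

# Computed once: a single op call on the original inputs, expanded to B.
out = op(x, y)
expected = out.unsqueeze(0).expand(B, *out.shape)

for batched_inputs, in_dims in all_permutations:
    result = vmap(op, in_dims=in_dims)(*batched_inputs)
    assert torch.allclose(result, expected)
```

The reference now costs one op call in total instead of `B` calls per batching; each iteration only runs the vmapped op and compares.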

This will probably save us something like 50% of runtime on our vmap tests.