Here's roughly how the vmap testing works (pseudocode):
all_permutations = create_all_batched_inputs()
for batched_inputs in all_permutations:
    result = vmap(op)(batched_inputs)
    expected = for_loop_over_op(op, batched_inputs)
    assert torch.allclose(result, expected)
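The loop above can be modeled without any dependencies. Everything below is a hypothetical stand-in, not functorch's actual test helpers: a "tensor" is a plain float, a batched tensor is a list along batch dim 0, vmapped_op plays the role of vmap(op), and batched inputs are assumed to be built by repeating the original input. The sketch counts how often the reference path calls op.

```python
import math
from itertools import product

BATCH = 3      # hypothetical batch size
op_calls = 0   # counts calls to op made by the reference path

def op(x, y):
    """Hypothetical op under test; 'tensors' are plain floats here."""
    global op_calls
    op_calls += 1
    return x * y + 1.0

def vmapped_op(x, y, in_dims):
    """Toy stand-in for vmap(op): a batching rule that computes the whole
    batch directly instead of calling op once per slice."""
    xs = x if in_dims[0] == 0 else [x] * BATCH
    ys = y if in_dims[1] == 0 else [y] * BATCH
    return [a * b + 1.0 for a, b in zip(xs, ys)]

def create_all_batched_inputs(inputs):
    """Yield (args, in_dims) for every way of batching at least one input.
    A batched input repeats the original BATCH times along dim 0."""
    for flags in product([False, True], repeat=len(inputs)):
        if any(flags):
            args = tuple([v] * BATCH if f else v for v, f in zip(inputs, flags))
            yield args, tuple(0 if f else None for f in flags)

def for_loop_over_op(op, args, in_dims):
    """Reference: slice each batched arg, call op once per slice, stack."""
    return [op(*(a[i] if d == 0 else a for a, d in zip(args, in_dims)))
            for i in range(BATCH)]

def allclose(a, b):
    return len(a) == len(b) and all(math.isclose(p, q) for p, q in zip(a, b))

def run_current_tests(inputs):
    for batched_inputs, in_dims in create_all_batched_inputs(inputs):
        result = vmapped_op(*batched_inputs, in_dims)
        expected = for_loop_over_op(op, batched_inputs, in_dims)  # BATCH op calls
        assert allclose(result, expected)

run_current_tests((2.0, 5.0))
print(op_calls)  # 3 permutations x 3 slices -> 9
```

With two inputs there are three ways to batch at least one of them, so in this model the reference path calls op 3 × 3 = 9 times for a single set of inputs.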
There are two things that could be optimized:
1. for_loop_over_op(op, batched_inputs). Instead of running a for-loop over the op with slices of batched_inputs, we can run the op once on the original input and expand the result along the batch dimension. This gives the same answer because every slice of batched_inputs is a copy of the original input.
2. expected is the SAME across all iterations of the for loop in the example above, so we only need to compute it once.
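Both points can be checked on a small toy model. The helpers below are hypothetical: "tensors" are floats, a batched tensor is a list along dim 0, batched inputs are assumed to repeat the original input, and expand is a stand-in for Tensor.expand rather than the real PyTorch call.

```python
from itertools import product

BATCH = 3  # hypothetical batch size

def op(x, y):
    """Hypothetical op under test; 'tensors' are plain floats here."""
    return x * y + 1.0

def create_all_batched_inputs(inputs):
    """Yield (args, in_dims) for every way of batching at least one input.
    A batched input repeats the original BATCH times along dim 0."""
    for flags in product([False, True], repeat=len(inputs)):
        if any(flags):
            args = tuple([v] * BATCH if f else v for v, f in zip(inputs, flags))
            yield args, tuple(0 if f else None for f in flags)

def for_loop_over_op(op, args, in_dims):
    """Reference: slice each batched arg, call op once per slice, stack."""
    return [op(*(a[i] if d == 0 else a for a, d in zip(args, in_dims)))
            for i in range(BATCH)]

def expand(value, batch_size=BATCH):
    """Toy analogue of Tensor.expand: repeat one result along a new batch dim."""
    return [value] * batch_size

inputs = (2.0, 5.0)
shortcut = expand(op(*inputs))  # a single call to op on the original inputs
per_permutation = [for_loop_over_op(op, args, dims)
                   for args, dims in create_all_batched_inputs(inputs)]

# Point 1: each for-loop reference equals the expanded single result, which
# also means (point 2) the reference is identical for every permutation.
assert all(ref == shortcut for ref in per_permutation)
print(shortcut)  # [11.0, 11.0, 11.0]
```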
The end-state should look something like:
all_permutations = create_all_batched_inputs()
expected = expand(op(inputs))
for batched_inputs in all_permutations:
    result = vmap(op)(batched_inputs)
    assert torch.allclose(result, expected)
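A toy version of this end state, again using hypothetical stand-ins (floats as tensors, lists along dim 0 as batched tensors, vmapped_op in place of vmap(op), and batched inputs assumed to repeat the original input), shows the reference computation moving out of the loop.

```python
import math
from itertools import product

BATCH = 3      # hypothetical batch size
op_calls = 0   # counts calls to op made by the reference path

def op(x, y):
    """Hypothetical op under test; 'tensors' are plain floats here."""
    global op_calls
    op_calls += 1
    return x * y + 1.0

def vmapped_op(x, y, in_dims):
    """Toy stand-in for vmap(op): computes the whole batch directly."""
    xs = x if in_dims[0] == 0 else [x] * BATCH
    ys = y if in_dims[1] == 0 else [y] * BATCH
    return [a * b + 1.0 for a, b in zip(xs, ys)]

def create_all_batched_inputs(inputs):
    """Yield (args, in_dims) for every way of batching at least one input.
    A batched input repeats the original BATCH times along dim 0."""
    for flags in product([False, True], repeat=len(inputs)):
        if any(flags):
            args = tuple([v] * BATCH if f else v for v, f in zip(inputs, flags))
            yield args, tuple(0 if f else None for f in flags)

def allclose(a, b):
    return len(a) == len(b) and all(math.isclose(p, q) for p, q in zip(a, b))

def run_optimized_tests(inputs):
    # Hoisted out of the loop: run op once on the unbatched inputs and repeat
    # the result along the batch dim (the analogue of expand(op(inputs))).
    expected = [op(*inputs)] * BATCH
    for batched_inputs, in_dims in create_all_batched_inputs(inputs):
        result = vmapped_op(*batched_inputs, in_dims)
        assert allclose(result, expected)

run_optimized_tests((2.0, 5.0))
print(op_calls)  # 1
```

In this model op is called once for the reference instead of once per slice per permutation. The vmapped calls are unchanged, so the real saving depends on how much of each test the reference loop accounts for.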
This will probably cut the runtime of our vmap tests by something like 50%.