yaroslavvb opened 1 year ago
Dynamo unrolls loops, and it's not particularly fast at tracing through them. This is not only an issue with NumPy, but also when tracing PyTorch.
For NumPy, what we found often works best is to compile just the body of the loop. You can even do some partial unrolling, where you compile a few iterations of the loop at a time.
Of course, this is just a workaround for the underlying issue, but yeah.
cc @Chillee for compile time. Also, as @Chillee mentioned elsewhere, a host-side control flow operator (e.g. something like jax.lax.fori_loop) could help here.
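For reference, the JAX operator mentioned above lowers the loop itself instead of unrolling it, so the body is traced exactly once:

```python
from jax import lax

# lax.fori_loop(lower, upper, body_fun, init_val) runs
# body_fun(i, carry) for i in [lower, upper) as a single lowered loop.
def body(i, carry):
    return carry + i

total = lax.fori_loop(0, 5, body, 0)
# total == 0 + 1 + 2 + 3 + 4 == 10
```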
I'm interested in poking around at this issue. Let me talk with others first though to get a feel for how much effort a possible solution could take.
@zou3519 a `fori_loop` primitive that was substituted automatically for Python functions involving `for` loops would resolve the issue
🐛 Describe the bug
I'm converting some numpy code, basically it's this for loop. It takes 30 seconds to compile.
Is it too much to ask to make it faster? :)
Works great otherwise!
```python
for step_idx in range(num_steps):
    X = hsqrt * np.random.randn(B, d)
    losses = np.einsum("BD,BD->B", E, X)
    E -= alpha * np.einsum("BD,B->BD", X, losses)
    traj[step_idx] = E
```
Converting the above to a higher-order `fori_loop` function automatically isn't too hard, ~~with the exception of the line `traj[step_idx] = E`. That line alone has side effects and would probably require recursion. However, the IR passed to Inductor is already recursive, so there would be no benefit there.~~ (I forgot that the non-functional stuff is pulled out of Inductor, so this could work too.)
One other possible way: trace one or two iterations until a fixpoint is reached (which should be quite fast since this is a trace), then express the loop as a recursive definition of state and compile that recursive definition. This would be a more efficient way of evaluating the loop than the current method, which fully unrolls the loop and then evaluates the resulting expressions.
(See: cyclic term graphs, Ariola & Klop 1996.)
An approach that requires less static-analysis juice is to require some user annotation. For example, you could annotate a function with a higher-order op that instructs Dynamo to never inline it; instead, it must be possible to compile it once and reuse it at all call sites. If the loop here can be rewritten to call such a function, that would also resolve your problem. The simplest implementation of this op would not allow Python side effects.
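No such "never inline" higher-order op exists today. As a rough stand-in sketch, `torch._dynamo.disable` at least shows the annotation shape: Dynamo skips the marked function instead of tracing into it (though the skipped function runs eagerly, so this does not give the compile-once-reuse semantics described above):

```python
import torch
import torch._dynamo as dynamo

@dynamo.disable
def record(buf, i, value):
    # Side effect Dynamo never traces: a graph break occurs at the call
    # site and this line runs as plain eager Python.
    buf[i] = value

@torch.compile
def run(n):
    buf = torch.zeros(n)
    for i in range(n):
        record(buf, i, float(2 * i))
    return buf
```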
> One other possible way: after one or two iterations and a fixpoint is reached (which should be quite fast since this is a trace) ...

Nevermind, after reasoning this out, it's not possible. That would require projecting traces or some form of symbolic analysis, which is completely at odds with the Dynamo architecture.
What is the current state of this?
Does the `torch._higher_order_ops.while_loop` in https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66 help with this case?
TorchDynamo still unrolls the loop aggressively. There is no easy workaround here. If it makes sense in your codebase, instead of applying `torch.compile` at the very top, you could lift the loop body into a separate function and then apply `torch.compile` manually on that lifted function. This might not be very user-friendly.
> TorchDynamo still unrolls the loop aggressively. There is no easy workaround here. If it makes sense in your codebase, instead of applying `torch.compile` at the very top, you could lift the loop body into a separate function and then apply `torch.compile` manually on that lifted function. This might not be very user-friendly.
I wonder if, long term, we can do that lifting implicitly during compilation.
Error logs
No response
Minified repro
Versions
This is on Google Colab and PyTorch nightly.
cc @ezyang @anijain2305 @chauhang @penguinwu @oulgen @jamesjwu @aorenste @laithsakka @zou3519 @ydwu4 @bdhirsh @msaroufim @wconstab