cicirori opened this issue 2 years ago
This zip contains the HLO dump results with GC enabled/disabled.
The number of model substructure cycles in the run that generated this dump was set to 4 instead of 64, but the conclusion is consistent: turning on GC on the GPU causes memory usage to increase in this case.
Hmm, is there a way for you to verify the peak memory usage and check whether it is reduced with GC?
I dumped the HLO for the with-checkpoint and without-checkpoint cases, and will try to find someone to take a look.
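For reference, this is roughly how an HLO/IR dump can be obtained from torch_xla (a minimal sketch assuming the usual XLA_SAVE_TENSORS_FILE / XLA_FLAGS environment variables, not necessarily the exact setup used for the dumps in this thread):

```python
# Hedged sketch: typical environment variables for dumping graphs from torch_xla,
# set before torch_xla is imported / the first computation runs.
import os

os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_ir_dump.txt"  # torch_xla graph dump
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"                    # save the graphs as HLO text
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_hlo_dump"   # per-compilation dumps from XLA itself

import torch_xla.core.xla_model as xm  # imported after the env vars are set
```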
I talked with Parker, who is the author of the optimization_barrier HLO, and I think my implementation of the optimization_barrier has a flaw. I was doing
x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
x2 = layer1.fwd(x1)
(x2,x1) = opt_barrier(x2, x1)
x3 = layer2.fwd(x2)
(x3,x2) = opt_barrier(x3, x2)
...
grad2 = layer2.bwd(x2, grad3)
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.bwd(x0, grad1)
which does guarantee that the repeated computation in the backward waits until the corresponding fwd finishes. However, there is nothing preventing the repeated computation from being moved to right after the corresponding fwd.
For example, it can become
x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
layer0.repeated_computation_for_backward
....
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.remaining_bwd(x0, grad1)
....
which is not ideal. I should really do it like this:
x1 = layer0.fwd(x0)
x2 = layer1.fwd(x1)
x3 = layer2.fwd(x2)
...
x2, grad3 = opt_barrier(x2, grad3)
grad2 = layer2.bwd(x2, grad3)
x1, grad2 = opt_barrier(x1, grad2)
grad1 = layer1.bwd(x1, grad2)
x0, grad1 = opt_barrier(x0, grad1)
grad0 = layer0.bwd(x0, grad1)
which will fully guarantee the execution order. I am not sure how much it will impact the real memory usage, but I will try to implement this soon. FYI @ronghanghu
In other words, instead of binding the input and output of the fwd, I should bind the grad_input and the input before performing the backward.
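To make that concrete, here is a minimal sketch (an illustrative BarrierCheckpoint helper, assuming xm.optimization_barrier_ is available; not the actual torch_xla.utils.checkpoint code or the change in the PR) of binding the saved input to the incoming gradient before recomputing:

```python
import torch
import torch_xla.core.xla_model as xm


class BarrierCheckpoint(torch.autograd.Function):
    """Illustrative sketch: recompute `fn` in backward, with the barrier
    binding the saved input to the incoming gradient so the recomputation
    cannot be hoisted next to the corresponding forward."""

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Bind the saved layer input to the incoming gradient: the
        # recomputation below now depends on grad_out through the barrier,
        # so it cannot be scheduled right after the corresponding fwd.
        xm.optimization_barrier_([x, grad_out])
        with torch.enable_grad():
            y = ctx.fn(x)
        torch.autograd.backward(y, grad_out)
        return None, x.grad
```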
Another point that Parker raised was that because the example is so simple, the XLA compiler is super adversarial. It is trying really hard to use those unused cores for something!
Let me first try to implement the change I propose above and see if that fixes the issue here.
@JackCaoG I see, thanks for the update on this!
@ronghanghu @cicirori Can you guys give https://github.com/pytorch/xla/pull/3721 a try?
Thanks @JackCaoG -- I'll try this out!
❓ Questions and Help
Recently I started testing GC performance on the GPU with the master versions of pytorch and torch xla.
Unfortunately, consistent with my previous conclusions (https://github.com/pytorch/xla/issues/3455#issuecomment-1101056839), the current torch xla GC still seems to have difficulty achieving the desired results, failing to deliver memory gains even on some very simple cases.
Simple test program:
Comparing GPU memory usage in four configs:
Running command:
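A minimal sketch of what such a comparison on XLA:GPU could look like (an illustrative stand-in assuming torch_xla.utils.checkpoint.checkpoint and a simple stack of linear layers; not the original test program or command):

```python
# Hypothetical sketch -- not the original test program. Toggle USE_GC and
# compare peak GPU memory (e.g. via nvidia-smi) with and without gradient
# checkpointing.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint

USE_GC = True
NUM_LAYERS = 64

device = xm.xla_device()
layers = nn.ModuleList(
    [nn.Sequential(nn.Linear(4096, 4096), nn.ReLU()) for _ in range(NUM_LAYERS)]
).to(device)

x = torch.randn(64, 4096, device=device, requires_grad=True)
out = x
for layer in layers:
    # With GC on, each layer's activations are recomputed during backward
    # instead of being kept alive across the whole forward pass.
    out = checkpoint(layer, out) if USE_GC else layer(out)

out.sum().backward()
xm.mark_step()  # materialize the graph so memory usage can be observed
```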