davidaknowles opened 4 months ago
So the short answer is that `xm.mark_step` is being added in the `ParallelLoader` for each batch; check my video, which actually covers this topic: https://youtu.be/LK3A3vjo-KQ?si=1TeK41p3AKszK8ou&t=737. It is not a good example though; I think I should explicitly add `torch_xla.sync` there to make it clearer. I am going to rewrite the model using the new eager mode + compile soon too.
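For anyone skimming this thread, here is a minimal sketch of that pattern on a single device (the toy model and data are placeholders, not the actual repo example); the loop itself never calls `xm.mark_step()` because `MpDeviceLoader` inserts one between batches:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Toy host-side data; MpDeviceLoader moves each batch to the XLA device.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
device_loader = pl.MpDeviceLoader(DataLoader(dataset, batch_size=8), device)

for data, target in device_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    # No explicit xm.mark_step() here: the loader inserts one between
    # batches, which is what cuts the lazy graph and triggers execution.
```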
Ah I see, hadn't thought of `ParallelLoader` doing that work. I understood `xm.optimizer_step` would be doing that in the DDP case, but this clears it up for the simple base case. Thanks!
BTW I think for beginners it would be great to have a do's and don'ts list somewhere. E.g. DO use `add_step_closure` to print metrics during training; DON'T access tensor shapes or call `item()` unless absolutely necessary; DON'T have varying tensor sizes; etc. I was also having issues with NaNs in some data tensors (to represent missing values) which were supposedly masked out (fine on GPU) but were giving me core dumps, so maybe using NaNs in that way is another DON'T.
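As an illustration of the `add_step_closure` point, a minimal sketch (reusing the hypothetical `model`/`optimizer`/`loss_fn`/`device_loader` names from the sketch above); the closure defers reading the loss until the step's graph has actually executed:

```python
import torch_xla.core.xla_model as xm

def _print_metrics(step, loss):
    # Runs only after the step's graph has executed, so calling .item()
    # here does not force an early sync in the middle of the step.
    print(f"step={step} loss={loss.item():.4f}")

for step, (data, target) in enumerate(device_loader):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    # Instead of print(loss.item()) inline, defer the read to a closure.
    xm.add_step_closure(_print_metrics, args=(step, loss))
```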
Here is my first attempt to make it clearer: https://github.com/pytorch/xla/pull/7642. This will only work if you use the nightly build right now, but the 2.4 release should also work when it is out later this month.
Thanks, will take a look.
As a minor aside: is there a best practice for running a validation loop? I normally just use a flag that turns off the train-specific stuff (`loss.backward()` etc.), but when I turn off `xm.optimizer_step` in the loop (using XLA DDP) something funky happens: I almost immediately run out of device memory.
`optimizer_step` won't call `mark_step` for you unless you set `barrier=True`. I think you can always manually call `mark_step` or `torch_xla.sync` after your step fn. If you look at the implementation of `torch_xla.experimental.compile`, it pretty much just calls `mark_step` before and after the function to make sure we compile the exact region you specified. More debugging details can be found in the video above.
📚 Documentation
This isn't necessarily an issue with the documentation, but an inconsistency between the documentation and the simplest PyTorch/XLA example. The docs say that the one key change to a standard training loop (for single-device use) is adding `xm.mark_step()`, but `train_resnet_base.py` doesn't have it (and just has `xm.wait_device_ops()` after all epochs are complete).

My understanding is that `xm.mark_step()` isn't necessary if we're not directly accessing any state on the TPU, which is why `train_resnet_base.py` doesn't use it and works around it via `xm.add_step_closure`. I assume the latter is actually preferred, but either way it would be helpful for folks getting started if there wasn't a confusing inconsistency like this for the simplest setting.

@JackCaoG I think this is your wheelhouse? Thanks for any clarification.
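For reference, the "one key change" loop the docs describe looks roughly like this (a minimal sketch with placeholder `model`/`optimizer`/`loss_fn`/`loader` names, not the actual `train_resnet_base.py` code):

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = model.to(device)

for data, target in loader:  # plain host-side DataLoader, no ParallelLoader
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # the "one key change": cut and execute the lazy graph
```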