pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Inconsistency between xla/examples/train_resnet_base.py and docs #7635

Open davidaknowles opened 2 months ago

davidaknowles commented 2 months ago

📚 Documentation

This isn't necessarily a problem with the documentation itself, but an inconsistency between the documentation and the simplest PyTorch XLA example. The docs say that the one key change to a standard training loop (for single-device use) is adding xm.mark_step(), but train_resnet_base.py doesn't have it (and just calls xm.wait_device_ops() after all epochs are complete).

My understanding is that xm.mark_step() isn't necessary if we're not directly accessing any state on the TPU, which is why train_resnet_base.py doesn't use it and instead works through xm.add_step_closure. I assume the latter is actually preferred, but either way it would be helpful for folks getting started not to hit a confusing inconsistency like this in the simplest setting.
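
For concreteness, my reading of the docs' pattern is roughly the sketch below (the tiny nn.Linear model and random batches are just placeholders, not anything from train_resnet_base.py):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(10):
    data = torch.randn(8, 10, device=device)
    target = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # the "one key change" from the docs: execute the pending graph
```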

@JackCaoG I think this is in your wheelhouse? Thanks for any clarification.

JackCaoG commented 2 months ago

So the short answer is that xm.mark_step is added by the ParallelLoader for each batch; check my video, which covers this topic: https://youtu.be/LK3A3vjo-KQ?si=1TeK41p3AKszK8ou&t=737.

It is not a good example though; I think I should explicitly add torch_xla.sync there to make it clearer.
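
Roughly, the wrapping in the example looks like this sketch (the TensorDataset loader here is just a stand-in for a real dataset):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))),
    batch_size=8)

# MpDeviceLoader (a thin wrapper around ParallelLoader) moves each batch to the
# XLA device and issues the per-batch mark_step itself, which is why the example's
# loop body never calls xm.mark_step() explicitly.
train_device_loader = pl.MpDeviceLoader(train_loader, device)

for data, target in train_device_loader:
    pass  # data/target already live on `device`; the graph is cut between batches
```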

JackCaoG commented 2 months ago

I am going to rewrite the model using the new eager mode + compile soon too.
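
For anyone curious, that style would look roughly like the sketch below; the torch_xla.experimental.eager_mode / torch_xla.experimental.compile names are as of the current nightly and may move, and the model is just a placeholder:

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)  # ops outside compiled regions run eagerly

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def train_step(data, target):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    return loss

# compile() traces just this function and marks the step around it,
# so no manual xm.mark_step() is needed in the loop.
train_step = torch_xla.experimental.compile(train_step)

for _ in range(10):
    data = torch.randn(8, 10, device=device)
    target = torch.randint(0, 2, (8,), device=device)
    loss = train_step(data, target)
```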

davidaknowles commented 2 months ago

Ah I see, I hadn't thought of ParallelLoader doing that work. I understood xm.optimizer_step would be doing that in the DDP case, but this clears it up for the simple base case. Thanks! BTW I think for beginners it would be great to have a do's and don'ts list somewhere, e.g. DO use add_step_closure to print metrics during training, DON'T access tensor shapes or call item() unless absolutely necessary, DON'T have varying tensor sizes, etc. I was also having issues with NaNs in some data tensors (to represent missing values) which were supposedly masked out (fine on GPU) but gave me core dumps, so maybe using NaNs in that way is another DON'T.
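
For the add_step_closure item, the pattern I have in mind is roughly the sketch below (same toy single-device loop as the earlier sketch, with a placeholder model):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def report(step, loss):
    # Runs only after the step's graph has executed, so .item() doesn't force an early sync.
    print(f"step {step}: loss {loss.item():.4f}")

for step in range(10):
    data = torch.randn(8, 10, device=device)
    target = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.add_step_closure(report, args=(step, loss))  # DO: defer metric access to a closure
    xm.mark_step()
```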

JackCaoG commented 2 months ago

Here is my first attempt to make it clearer: https://github.com/pytorch/xla/pull/7642. This will only work on nightly right now, but the 2.4 release should also work when it is out later this month.

davidaknowles commented 2 months ago

Thanks, will take a look.

As a minor aside: is there a best practice for running a validation loop? I normally just use a flag that turns off the train-specific stuff (loss.backward(), etc.), but when I turn off xm.optimizer_step in the loop (using XLA DDP) something funky happens: I almost immediately run out of device memory.

JackCaoG commented 2 months ago

optimizer_step won't call mark_step for you unless you set barrier=True. I think you can always manually call mark_step or torch_xla.sync after your step fn. If you look at the implementation of torch_xla.experimental.compile, it pretty much just calls mark_step before and after the function to make sure we compile exactly the region you specified. More debugging details can be found in the video above.
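
Something like this sketch for the eval loop (placeholder model and random batches standing in for a real validation loader):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 2).to(device)  # placeholder model
loss_fn = nn.CrossEntropyLoss()

model.eval()
total_loss = 0.0
with torch.no_grad():
    for _ in range(10):  # stand-in for iterating a validation loader
        data = torch.randn(8, 10, device=device)
        target = torch.randint(0, 2, (8,), device=device)
        total_loss += loss_fn(model(data), target)
        xm.mark_step()  # cut and execute the graph each batch so it can't grow unbounded

print(f"val loss: {total_loss.item() / 10:.4f}")
```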