pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[torch_xla2] Basic DDP example in a notebook #8042

Closed will-cromar closed 1 month ago

will-cromar commented 2 months ago

A very basic example of torch_xla2 training that highlights the sharding, replication, and JIT utilities.
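
For readers skimming the PR, here is a rough sketch (not the notebook's actual code) of the sharding/replication/JIT pattern the example demonstrates, written against the plain `jax.sharding` APIs that torch_xla2 builds on. The `shard_input`, `replicate`, and `mse_loss` names below are illustrative only, and the model is a stand-in linear layer rather than a real torch module run through torch_xla2.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over all local devices (e.g. the four chips of a v4-8 host).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

def shard_input(x):
    # Split the leading (batch) dimension of the input across devices.
    return jax.device_put(x, NamedSharding(mesh, P("batch")))

def replicate(params):
    # Keep every parameter fully replicated on all devices (DDP-style).
    return jax.tree_util.tree_map(
        lambda p: jax.device_put(p, NamedSharding(mesh, P())), params)

@jax.jit
def mse_loss(params, batch):
    # Stand-in linear model; the notebook runs a real torch module instead.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)
```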

I intended this to be a dead-simple example that a user could run without burning through their Colab quota, but I ran out of Colab quota myself while developing it. I ended up finishing the work and generating the output on a v4-8. We should adapt this for Colab in the future to make it more accessible.

Minor change: return a Tensor from shard_input instead of an Array :facepalm: