Very basic example of torch_xla2 training that highlights the sharding, replication, and JIT utilities.
I intended this to be a dead-simple example that a user could run without burning through their Colab quota, but I ran out of Colab quota myself while developing this. I ended up finishing the work and generating the output on a v4-8. We should adapt this to Colab in the future to make it more accessible, though.
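The example's code isn't reproduced here, but the three utilities it highlights (sharding an input across devices, replicating weights, and JIT compilation) can be sketched in plain JAX, which torch_xla2 builds on. This is an illustrative sketch, not the example itself; the shapes and the `forward` function are made up for demonstration.

```python
import os
# Fake 4 devices on CPU so the sketch runs anywhere (must be set before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("batch",))

# Shard the input batch across devices along its leading axis.
x = jnp.ones((8, 16))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("batch")))

# Replicate the weights on every device (empty PartitionSpec = replicated).
w = jnp.ones((16, 4))
w_replicated = jax.device_put(w, NamedSharding(mesh, P()))

# JIT-compile the forward pass; XLA propagates the input shardings.
@jax.jit
def forward(x, w):
    return x @ w

y = forward(x_sharded, w_replicated)
print(y.shape)  # (8, 4)
```

On a real TPU slice such as the v4-8 mentioned above, the same code picks up the hardware devices instead of the faked CPU ones.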
Minor change: return a `Tensor` from `shard_input` instead of an `Array` :facepalm: