pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Using mark_sharding vs. MpDeviceLoader with input_sharding=xs.ShardingSpec #7854

Open dudulightricks opened 2 months ago

dudulightricks commented 2 months ago

❓ Questions and Help

If we have several tensors of different sizes in a batch and we call mark_sharding on each of them, do we lose anything compared to using input_sharding=xs.ShardingSpec with MpDeviceLoader (which only works when every tensor in the batch has the same size)? @JackCaoG
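For context, a minimal sketch of the two approaches being compared. The mesh layout, dataset, and partition spec here are illustrative assumptions, not details from the thread:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()  # enable the SPMD sharding mode

num_devices = xr.global_runtime_device_count()
# Hypothetical 1-D mesh sharding the batch dimension across all devices.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
device = xm.xla_device()

# Placeholder dataset; shapes are arbitrary for illustration.
dataset = torch.utils.data.TensorDataset(torch.randn(32, 128))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Option A: one ShardingSpec applied by the loader -- requires every
# batch tensor to match the spec, i.e. a single shape per batch.
spec = xs.ShardingSpec(mesh, ('data', None))
loader = pl.MpDeviceLoader(train_loader, device, input_sharding=spec)
for (batch,) in loader:
    pass  # batch arrives on device already sharded (and prefetched)

# Option B: manual mark_sharding per tensor -- handles mixed shapes,
# but the annotation happens after each tensor reaches the device.
for (batch,) in train_loader:
    batch = batch.to(device)
    xs.mark_sharding(batch, mesh, ('data', None))
```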

JackCaoG commented 2 months ago

@alanwaketan can you take this one?

alanwaketan commented 2 months ago

Can you pad your tensors?
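In case it is useful, padding variable-size tensors to a common shape might look like the sketch below; the helper name, sizes, and pad value are hypothetical:

```python
import torch
import torch.nn.functional as F

def pad_to_length(t: torch.Tensor, target_len: int, value: float = 0.0) -> torch.Tensor:
    # Right-pad the last dimension of `t` to `target_len` (hypothetical helper).
    return F.pad(t, (0, target_len - t.shape[-1]), value=value)

batch = [torch.randn(8, n) for n in (100, 120, 128)]  # mixed sizes
max_len = max(t.shape[-1] for t in batch)
padded = torch.stack([pad_to_length(t, max_len) for t in batch])  # shape (3, 8, 128)
```

Once every tensor in the batch has the same shape, a single input_sharding spec can cover all of them.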

dudulightricks commented 2 months ago

@alanwaketan Yes, I can, but I'm asking whether padding would give me any benefit compared to just calling mark_sharding on each tensor.

alanwaketan commented 1 month ago

The dataloader will prefetch the data onto the device; that's the biggest benefit you get from using any of the data loaders.
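A minimal sketch of that pattern, with a placeholder model and dataset: MpDeviceLoader queues the host-to-device transfer of the next batch while the current step computes, and its iterator also calls xm.mark_step() at batch boundaries for you.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
model = torch.nn.Linear(128, 1).to(device)  # placeholder model

dataset = torch.utils.data.TensorDataset(torch.randn(64, 128))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Batch N+1 is moved to the device in the background while step N runs,
# so the device is not left idle waiting on input transfers.
loader = pl.MpDeviceLoader(train_loader, device)
for (data,) in loader:
    loss = model(data).sum()  # placeholder compute
    loss.backward()
```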