Closed Kaushal28 closed 3 years ago
If your function involves solely data preparation, you can roll it into your dataloader's collate function. Have a look at the PyTorch docs and search for `collate_fn` to see how to use it.
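For illustration only (none of this code is from the thread), here is a minimal sketch of rolling a CutMix-style mix into a custom `collate_fn`, so the mixing runs on CPU in the dataloader rather than inside the XLA training step; `cutmix_collate`, `alpha`, and `dataset` are made-up names for this example:

```python
import torch

def cutmix_collate(samples, alpha=1.0):
    # Stack individual (image, label) samples into a CPU batch.
    images = torch.stack([img for img, _ in samples])
    labels = torch.tensor([label for _, label in samples])

    # Sample the mixing ratio and a random permutation of the batch.
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    perm = torch.randperm(images.size(0))

    # Cut a box covering roughly (1 - lam) of the image area and paste
    # it from the permuted images into the originals.
    _, _, h, w = images.shape
    cut_h, cut_w = int(h * (1 - lam) ** 0.5), int(w * (1 - lam) ** 0.5)
    cy, cx = int(torch.randint(h, (1,))), int(torch.randint(w, (1,)))
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]

    # Adjust lam to the exact pasted area and return both label sets so
    # the training loop can compute the mixed loss.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    return images, labels, labels[perm], lam

# loader = torch.utils.data.DataLoader(dataset, batch_size=64,
#                                      collate_fn=cutmix_collate)
```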
Are you using torch_xla's parallel loaders as here? If so, you don't need to send batches to the device explicitly with
`X, y = batch[0].to(device), batch[1].to(device)  # device is xla`
To wrap up, assuming the `cutmix` function is completely separate from the model, I suggest you use `pl.MpDeviceLoader` to wrap your dataloader for best results.
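As a rough sketch of that wrapping (here `train_loader`, `model`, `loss_fn`, and `optimizer` are placeholders, not code from this issue):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()

# Wrap the existing CPU DataLoader; MpDeviceLoader feeds batches to the
# XLA device in the background, so no per-batch .to(device) calls are needed.
train_device_loader = pl.MpDeviceLoader(train_loader, device)

for X, y in train_device_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)   # X and y already live on the XLA device
    loss.backward()
    xm.optimizer_step(optimizer)  # syncs gradients across the TPU cores
```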
@Kaushal28 does this work for you? Is the issue good to close?
closing, please re-open if needed.
I have existing code to train EfficientNet using PyTorch, which contains custom augmentations like CutMix, MixUp, etc. in the training loop. This runs perfectly on GPU. Now I want to change my code so that it can run on TPUs.
I've made the required changes to run my code on 8 TPU cores using PyTorch XLA, but it runs very slowly when I use the custom augmentations in the training loop (even slower than on GPU). When I remove them, it runs significantly faster. So I think I have to make changes in my augmentation functions as well.
Here is my training loop.
And here is my complete `cutmix` function, which causes the issue:

Whenever I create tensors, I move them to the xla device, but this still slows down the training loop on TPUs.
So my question is: how can I write pure Python functions (here `cutmix` is a pure Python function which just does some processing with image tensors) that can run efficiently on TPUs? What changes should I make here? Am I supposed to create all new variables on the "xla" device?

EDIT: I tried converting everything to tensors (on the xla device) in the `cutmix` function, but still no speed gain.

Thanks.