pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

How to write a pure Python function that can run on TPUs while using PyTorch-XLA? #2707

Closed Kaushal28 closed 3 years ago

Kaushal28 commented 3 years ago

I have existing code for training EfficientNet with PyTorch that applies custom augmentations like CutMix and MixUp inside the training loop. It runs perfectly on a GPU. Now I want to change the code so that it can run on TPUs.

I've made the required changes to run my code on 8 TPU cores using PyTorch XLA, but it runs very slowly when I use the custom augmentations in the training loop (even slower than on the GPU). When I remove them, it runs significantly faster, so I think I have to change my augmentation functions as well.

Here is my training loop.

def train():
    for batch in train_loader:
        X, y = batch[0].to(device), batch[1].to(device)  # device is xla
        cutmixup_prob = random.random()

        if cutmixup_prob > 0.4:
            X, y, y_shuffled, lam = cutmix(X, y, 0.4)

        # forward pass
        # calc. loss
        # backward pass
        xm.optimizer_step(optimizer)

        # calc. and return accuracy

And here is my complete cutmix function, which causes issues:

# https://www.kaggle.com/c/bengaliai-cv19/discussion/126504
import numpy as np
import torch

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int is deprecated; the builtin int does the same job
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix(images, targets, alpha):
    device = images.device
    indices = torch.randperm(images.size(0)).to(device)
    shuffled_targets = targets[indices].to(device)

    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    # Cutmix
    images[:, :, bbx1:bbx2, bby1:bby2] = images[indices, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
    return images, targets, shuffled_targets, lam

Whenever I create tensors, I move them to the xla device, but the training loop is still slow on TPUs.

So my question is: how can I write pure Python functions (here cutmix is a pure Python function that just does some processing on image tensors) so that they run efficiently on TPUs? What changes should I make? Am I supposed to create all new variables on the "xla" device?

EDIT: I tried converting everything to tensors (on the xla device) inside the cutmix function, but still got no speed gain.

Thanks.

taylanbil commented 3 years ago

If your function only involves data preparation, you can roll it into your dataloader's collate function. Have a look at the PyTorch docs and search for collate_fn to see how to use it.
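For instance, here is a minimal sketch (not from the original thread) of folding the cutmix call into a custom collate_fn, so the augmentation runs on CPU inside the dataloader workers before the batch ever reaches the XLA device. The dataset name, batch size, and worker count are placeholders; the 0.4 threshold mirrors the training loop above.

import random
import torch
from torch.utils.data import DataLoader

def cutmix_collate_fn(samples):
    # Standard collation: stack per-sample tensors into CPU batch tensors.
    images = torch.stack([s[0] for s in samples])
    targets = torch.tensor([s[1] for s in samples])

    # Apply cutmix (defined earlier in this issue) with the same probability
    # as the training loop; otherwise return an "identity" mix (lam = 1.0).
    if random.random() > 0.4:
        return cutmix(images, targets, 0.4)
    return images, targets, targets, 1.0

# `train_dataset` stands in for whatever Dataset the training loop already uses.
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=4,
                          collate_fn=cutmix_collate_fn)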

Are you using torch_xla's parallel loaders, as here? If so, you don't need to send batches to the device explicitly with

        X, y = batch[0].to(device), batch[1].to(device)  # device is xla

To wrap up, assuming the cutmix function is completely separate from the model, I suggest you

  1. roll it into your dataloader's collate_fn
  2. use pl.MpDeviceLoader to wrap your dataloader (see the sketch below)

for best results.
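A minimal sketch of step 2, assuming the collate-based loader from the example above; with pl.MpDeviceLoader the batches arrive already on the XLA device, so the explicit .to(device) calls can be dropped from the training loop.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
# Wrap the existing dataloader; it yields batches already placed on `device`.
para_loader = pl.MpDeviceLoader(train_loader, device)

def train():
    for X, y, y_shuffled, lam in para_loader:
        # forward pass
        # calc. loss
        # backward pass
        xm.optimizer_step(optimizer)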

taylanbil commented 3 years ago

@Kaushal28 does this work for you? Is the issue good to close?

taylanbil commented 3 years ago

closing, please re-open if needed.