pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Enable multiprocessing on pytorch XLA for TPU VM #4893

Open adhamali450 opened 1 year ago

adhamali450 commented 1 year ago

I'm fairly new to this and have little to no experience. I had a notebook running PyTorch that I wanted to run on a Google Cloud TPU VM. Machine specs:

- Ubuntu
- TPU v3-8
- pt-2.0

I wanted to use all 8 cores I have for training a neural network. When the interpreter reaches the lines where I move the models to the device, it does nothing and prints what appears to be a warning: https://symbolize.stripped_domain/r/?trace=7fcddb59105e,7fcfbafdd08f,.

The code

# Standard imports used below (audio_to_mel, AudioToImageFolder, the model
# classes, ssim, data_depth, hidden_size and DATASET_PATH are presumably
# defined in earlier notebook cells)
import gc
import os

import torch
from torch.nn.functional import binary_cross_entropy_with_logits, mse_loss
from torch.optim import Adam
from torchvision import transforms
from tqdm import tqdm

# TPU-specific libraries (must-haves)
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

transform = transforms.Compose([transforms.Lambda(lambda wav: audio_to_mel(wav)['real_imag_pair'])])

train_set = AudioToImageFolder(os.path.join(
    DATASET_PATH, "train/"), transform=transform)
valid_set = AudioToImageFolder(os.path.join( 
    DATASET_PATH, "val/"), transform=transform)

def map_fn(index, flags):
    xm.master_print("Creating the datasamplers")
    train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_set,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(), 
      shuffle=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_set, 
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(), 
        shuffle=False
    )

    xm.master_print("Creating the dataloaders")
    train_loader = torch.utils.data.DataLoader(
      train_set,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True
    )

    valid_loader = torch.utils.data.DataLoader(
      valid_set,
      batch_size=flags['batch_size'],
      sampler=valid_sampler,
      num_workers=flags['num_workers'],
      drop_last=True
    )

    device = xm.xla_device()

    xm.master_print("Creating the para dataloaders")
    para_loader_train = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    para_loader_valid = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)

    xm.master_print("Creating the models")
    encoder = DenseEncoder(data_depth, hidden_size,flags['channels_size'])
    decoder = DenseDecoder(data_depth, hidden_size,flags['channels_size'])
    critic = BasicCritic(hidden_size,flags['channels_size'])

    xm.master_print("Sending to device")

    # It prints 'Sending to device' and then crashes...

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    critic = critic.to(device)

    xm.master_print("Creating the optimizers") # not printed
    cr_optimizer = Adam(critic.parameters(), lr=1e-4)
    en_de_optimizer = Adam(list(decoder.parameters()) + list(encoder.parameters()), lr=1e-4)

    # Training loop
    iter_train_critic=0
    iter_train_enc_dec=0
    iter_valid=0

    xm.master_print("Will train...")

    for ep in range(0, flags['num_epochs']):
        encoder = encoder.train()
        decoder = decoder.train()
        critic = critic.train()

        print("Epoch %d" %(ep+1))
        for cover, *rest in tqdm(para_loader_train):
            iter_train_critic+=1
            gc.collect()
            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, data_depth, H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            cover_score = torch.mean(critic.forward(cover))
            generated_score = torch.mean(critic.forward(generated))

            cr_optimizer.zero_grad()
            (cover_score - generated_score).backward(retain_graph=False)

            xm.optimizer_step(cr_optimizer)

            for p in critic.parameters():
                p.data.clamp_(-0.1, 0.1)

        for cover, *rest in tqdm(para_loader_train):
            iter_train_enc_dec+=1
            gc.collect()
            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, data_depth, H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            decoded = decoder.forward(generated)
            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))

            en_de_optimizer.zero_grad()
            (100 * encoder_mse + decoder_loss +
             generated_score).backward()  # Why 100?

            xm.optimizer_step(en_de_optimizer)

        # Validation after every epoch

        encoder = encoder.eval()
        decoder = decoder.eval()
        critic = critic.eval()    

        for cover, *rest in tqdm(para_loader_valid):
            iter_valid+=1
            gc.collect()
            cover = cover.to(device)

            N, _, H, W = cover.size()

            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, data_depth, H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)

            decoded = decoder.forward(generated)

            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))
            cover_score = torch.mean(critic.forward(cover))

            ssim_=ssim(cover, generated)
            psnr_=10 * torch.log10(4 / encoder_mse)
            bbp_=data_depth * (2 * decoder_acc.item() - 1)

#         print('encoder_mse: %.3f - decoder_loss: %.3f - decoder_acc: %.3f - cover_score: %.3f - generated_score: %.3f - ssim: %.3f - psnr: %.3f - bpp: %.3f'
#           %(encoder_mse.item(),decoder_loss.item(),decoder_acc.item(),cover_score.item(),generated_score.item(), ssim_.item(),psnr_.item(),bbp_))

        xm.master_print("Finished training epoch {}".format(ep))

flags = {}
flags['batch_size'] = 2
flags['num_workers'] = 8 
flags['num_epochs'] = 32 
flags['channels_size'] = 2
# flags['seed'] = 42

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

So please, if you see anything wrong with this code or have any suggestions, let me know. I really need help getting this to work.

JackCaoG commented 1 year ago

How do you configure PyTorch/XLA? We recommend PJRT. Can you do a sanity check following https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpu (the current master doc still points to 1.13, so maybe try https://github.com/pytorch/xla/blob/JackCaoG-patch-2/docs/pjrt.md#tpu instead) and see if resnet runs? There are some subtle differences, but I believe TPU v3 should also run with PJRT.
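
For a quick wiring check of PJRT itself, a minimal script along these lines might help (a sketch against the torch_xla 2.0 API, not taken from the linked doc; it just spawns one process per core and prints each ordinal):

# Minimal PJRT sanity check (sketch; assumes torch_xla 2.0 on a TPU VM)
import os
os.environ['PJRT_DEVICE'] = 'TPU'  # select the PJRT TPU runtime

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _check(index):
    device = xm.xla_device()
    t = torch.ones(2, 2, device=device) * xm.get_ordinal()
    # .item() forces execution of the pending XLA graph
    print(f"ordinal {xm.get_ordinal()} / world size {xm.xrt_world_size()}: sum={t.sum().item()}")

if __name__ == '__main__':
    # With PJRT, nprocs is typically left as None so all local devices are used
    xmp.spawn(_check, args=(), nprocs=None, start_method='fork')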

adhamali450 commented 1 year ago

I followed https://github.com/pytorch/xla/blob/JackCaoG-patch-2/docs/pjrt.md#tpu and here's the updated code, as per the readme:

# Standard imports used below (model classes, dataset helpers, ssim and
# save_model are presumably defined elsewhere in the notebook)
import gc
import os

import torch
from torch.nn.functional import binary_cross_entropy_with_logits, mse_loss
from torch.optim import Adam
from torchvision import transforms
from tqdm import notebook

# TPU-specific libraries (must-haves)
import torch.distributed as dist

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torch_xla.experimental.pjrt_backend
import torch_xla.experimental.pjrt as pjrt

transform = transforms.Compose([transforms.Lambda(lambda wav: audio_to_mel(wav, random_crop=True)['real_imag_pair'])])

train_set = AudioToImageFolder(os.path.join(
    DATASET_PATH, "train/"), transform=transform)
valid_set = AudioToImageFolder(os.path.join( 
    DATASET_PATH, "val/"), transform=transform)

def map_fn(index, flags):
    device = xm.xla_device()
    dist.init_process_group('xla', init_method='pjrt://')

    # Prepare for saving results
    if not xm.is_master_ordinal():
          xm.rendezvous('create_dirs_once')

    for func in [
            lambda: os.mkdir(os.path.join('.', f'./results')),
            lambda: os.mkdir(os.path.join('.', f'./results/model')),
            lambda: os.mkdir(os.path.join('.', f'./results/plots'))]:
        try:
          func()
        except Exception as error:
          print(error)
          continue

    if xm.is_master_ordinal():
      xm.rendezvous('create_dirs_once')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_set,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(), 
      shuffle=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_set, 
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(), 
        shuffle=False
    )

    train_loader = torch.utils.data.DataLoader(
      train_set,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True
    )

    valid_loader = torch.utils.data.DataLoader(
      valid_set,
      batch_size=flags['batch_size'],
      sampler=valid_sampler,
      num_workers=flags['num_workers'],
      drop_last=True
    )

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    valid_device_loader = pl.MpDeviceLoader(valid_loader, device)

#     para_loader_train = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
#     para_loader_valid = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)

    encoder = DenseEncoder(flags['data_depth'], flags['hidden_size'],flags['channels_size'])
    decoder = DenseDecoder(flags['data_depth'], flags['hidden_size'],flags['channels_size'])
    critic = BasicCritic(flags['hidden_size'],flags['channels_size'])

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    critic = critic.to(device)

    cr_optimizer = Adam(critic.parameters(), lr=1e-4)
    en_de_optimizer = Adam(list(decoder.parameters()) + list(encoder.parameters()), lr=1e-4)

    metrics = {field: list() for field in flags['METRIC_FIELDS']}

    # Training loop
    iter_train_critic=0
    iter_train_enc_dec=0
    iter_valid=0

    for ep in range(0, flags['num_epochs']):
        encoder = encoder.train()
        decoder = decoder.train()
        critic = critic.train()

        print("Epoch %d" %(ep+1))
        for cover, *rest in notebook.tqdm(train_device_loader):
            iter_train_critic+=1
            gc.collect()

            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            cover_score = torch.mean(critic.forward(cover))
            generated_score = torch.mean(critic.forward(generated))

            cr_optimizer.zero_grad()
            (cover_score - generated_score).backward(retain_graph=False)

            xm.optimizer_step(cr_optimizer)

            for p in critic.parameters():
                p.data.clamp_(-0.1, 0.1)

            metrics['train.cover_score'].append(cover_score.item())
            metrics['train.generated_score'].append(generated_score.item())

        for cover, *rest in notebook.tqdm(train_device_loader):
            iter_train_enc_dec+=1
            gc.collect()

            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N,flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            decoded = decoder.forward(generated)
            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))

            en_de_optimizer.zero_grad()
            (100 * encoder_mse + decoder_loss +
             generated_score).backward()  # Why 100?

            xm.optimizer_step(en_de_optimizer)

            metrics['train.encoder_mse'].append(encoder_mse.item())
            metrics['train.decoder_loss'].append(decoder_loss.item())
            metrics['train.decoder_acc'].append(decoder_acc.item())

        # Validation after every epoch

        encoder = encoder.eval()
        decoder = decoder.eval()
        critic = critic.eval()    

        for cover, *rest in notebook.tqdm(valid_device_loader):
            iter_valid+=1
            gc.collect()

            cover = cover.to(device)
            N, _, H, W = cover.size()

            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N,flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)

            decoded = decoder.forward(generated)

            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))
            cover_score = torch.mean(critic.forward(cover))

            ssim_=ssim(cover, generated)
            psnr_=10 * torch.log10(4 / encoder_mse)
            bbp_=flags['data_depth']* (2 * decoder_acc.item() - 1)

            metrics['val.encoder_mse'].append(encoder_mse.item())
            metrics['val.decoder_loss'].append(decoder_loss.item())
            metrics['val.decoder_acc'].append(decoder_acc.item())
            metrics['val.cover_score'].append(cover_score.item())
            metrics['val.generated_score'].append(generated_score.item())
            metrics['val.ssim'].append(ssim_.item())
            metrics['val.psnr'].append(psnr_.item())
            metrics['val.bpp'].append(bbp_)

        xm.master_print('encoder_mse: %.3f - decoder_loss: %.3f - decoder_acc: %.3f - cover_score: %.3f - generated_score: %.3f - ssim: %.3f - psnr: %.3f - bpp: %.3f'
          %(encoder_mse.item(),decoder_loss.item(),decoder_acc.item(),cover_score.item(),generated_score.item(), ssim_.item(),psnr_.item(),bbp_))

        if not xm.is_master_ordinal():
          xm.rendezvous('save_only_once')

        save_model(encoder, decoder, critic, en_de_optimizer, cr_optimizer, metrics, flags['num_epochs'], flags)

        if xm.is_master_ordinal():
          xm.rendezvous('save_only_once')

flags['METRIC_FIELDS'] = [
    'val.encoder_mse',
    'val.decoder_loss',
    'val.decoder_acc',
    'val.cover_score',
    'val.generated_score',
    'val.ssim',
    'val.psnr',
    'val.bpp',
    'train.encoder_mse',
    'train.decoder_loss',
    'train.decoder_acc',
    'train.cover_score',
    'train.generated_score'
]

if __name__ == '__main__':
    os.environ['PJRT_DEVICE'] = 'TPU'
    xmp.spawn(map_fn, args=(flags,), nprocs=1, start_method='fork')

Also, here is the Colab notebook if you want to investigate further. You have edit access as well.

JackCaoG commented 1 year ago

PJRT doesn't work on Colab; were you able to get the above code to run on the TPU VM? I would recommend a few things:

  1. Try out the resnet example and see if it runs.
  2. Comment out the xmp.spawn and call map_fn directly to see if the above code runs on a single core (a sketch of this follows the list).
  3. Try out smaller code first instead of a full model if this is your first time running on a TPU.
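
For point 2, a minimal way to try it (a sketch, assuming the map_fn and flags defined in your snippet above):

# Single-process debug run (sketch): bypass xmp.spawn to separate
# multiprocessing issues from problems in the model code itself
import os
os.environ['PJRT_DEVICE'] = 'TPU'

# xmp.spawn(map_fn, args=(flags,), nprocs=1, start_method='fork')
map_fn(0, flags)  # call directly with index 0 and the same flags dict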

You can also check out our troubleshooting doc. You can find more docs under https://pytorch.org/xla/release/2.0/index.html#
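
As a starting point from the troubleshooting doc, the metrics module you already import (torch_xla.debug.metrics) can show whether the graph is compiling and executing at all; a short sketch:

import torch_xla.debug.metrics as met

# ... after a few training steps ...
print(met.metrics_report())  # compile/execute counters help spot recompilation or CPU fallbacks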

adhamali450 commented 1 year ago

You mentioned that there are subtle differences. Can you point them out, and say whether addressing them would solve the issue or not?