adhamali450 opened this issue 1 year ago
How do you configure PyTorch/XLA? We recommend PJRT. Can you do a sanity check following https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpu (the current master doc still points to 1.13, so maybe try https://github.com/pytorch/xla/blob/JackCaoG-patch-2/docs/pjrt.md#tpu instead) and see if resnet runs? There are some subtle differences, but I believe TPU v3 should also run with PJRT.
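For a quick sanity check before the full resnet script, something like the following should run a small op on the TPU under PJRT (a minimal sketch, assuming a TPU VM with torch_xla installed):

import os
os.environ['PJRT_DEVICE'] = 'TPU'  # select the PJRT TPU runtime

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randn(2, 2, device=device)
print(device, t @ t)  # should report an xla device, e.g. xla:0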
I followed https://github.com/pytorch/xla/blob/JackCaoG-patch-2/docs/pjrt.md#tpu, and here's the updated code as per the readme:
# TPU-specific libraries (must-haves)
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torch_xla.experimental.pjrt_backend
import torch_xla.experimental.pjrt as pjrt

# Standard libraries used below
import gc
import os
import torch
from torch.optim import Adam
from torch.nn.functional import mse_loss, binary_cross_entropy_with_logits
from torchvision import transforms
from tqdm import notebook
# Project-specific helpers (DATASET_PATH, AudioToImageFolder, audio_to_mel,
# DenseEncoder, DenseDecoder, BasicCritic, ssim, save_model, flags) are
# defined elsewhere in the notebook.

transform = transforms.Compose(
    [transforms.Lambda(lambda wav: audio_to_mel(wav, random_crop=True)['real_imag_pair'])])

train_set = AudioToImageFolder(os.path.join(DATASET_PATH, "train/"), transform=transform)
valid_set = AudioToImageFolder(os.path.join(DATASET_PATH, "val/"), transform=transform)
def map_fn(index, flags):
    device = xm.xla_device()
    dist.init_process_group('xla', init_method='pjrt://')  # added per the PJRT doc

    # Prepare for saving results: non-master processes block at the
    # rendezvous while the master creates the output directories first;
    # the others then re-run the mkdirs and the FileExistsErrors are
    # caught and printed.
    if not xm.is_master_ordinal():
        xm.rendezvous('create_dirs_once')
    for func in [
            lambda: os.mkdir(os.path.join('.', 'results')),
            lambda: os.mkdir(os.path.join('.', 'results/model')),
            lambda: os.mkdir(os.path.join('.', 'results/plots'))]:
        try:
            func()
        except Exception as error:
            print(error)
            continue
    if xm.is_master_ordinal():
        xm.rendezvous('create_dirs_once')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=flags['batch_size'],
        sampler=valid_sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    valid_device_loader = pl.MpDeviceLoader(valid_loader, device)
    # para_loader_train = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    # para_loader_valid = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)
    encoder = DenseEncoder(flags['data_depth'], flags['hidden_size'], flags['channels_size'])
    decoder = DenseDecoder(flags['data_depth'], flags['hidden_size'], flags['channels_size'])
    critic = BasicCritic(flags['hidden_size'], flags['channels_size'])
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    critic = critic.to(device)

    cr_optimizer = Adam(critic.parameters(), lr=1e-4)
    en_de_optimizer = Adam(list(decoder.parameters()) + list(encoder.parameters()), lr=1e-4)
    metrics = {field: list() for field in flags['METRIC_FIELDS']}

    # Training loop
    iter_train_critic = 0
    iter_train_enc_dec = 0
    iter_valid = 0
    for ep in range(0, flags['num_epochs']):
        encoder = encoder.train()
        decoder = decoder.train()
        critic = critic.train()
        print("Epoch %d" % (ep + 1))

        # First pass over the training data: update the critic
        for cover, *rest in notebook.tqdm(train_device_loader):
            iter_train_critic += 1
            gc.collect()
            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            cover_score = torch.mean(critic.forward(cover))
            generated_score = torch.mean(critic.forward(generated))

            cr_optimizer.zero_grad()
            (cover_score - generated_score).backward(retain_graph=False)
            xm.optimizer_step(cr_optimizer)

            for p in critic.parameters():
                p.data.clamp_(-0.1, 0.1)

            metrics['train.cover_score'].append(cover_score.item())
            metrics['train.generated_score'].append(generated_score.item())
        # Second pass: update the encoder and decoder
        for cover, *rest in notebook.tqdm(train_device_loader):
            iter_train_enc_dec += 1
            gc.collect()
            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            decoded = decoder.forward(generated)

            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))

            en_de_optimizer.zero_grad()
            (100 * encoder_mse + decoder_loss +
             generated_score).backward()  # Why 100?
            xm.optimizer_step(en_de_optimizer)

            metrics['train.encoder_mse'].append(encoder_mse.item())
            metrics['train.decoder_loss'].append(decoder_loss.item())
            metrics['train.decoder_acc'].append(decoder_acc.item())
        # Validation after every epoch
        encoder = encoder.eval()
        decoder = decoder.eval()
        critic = critic.eval()
        for cover, *rest in notebook.tqdm(valid_device_loader):
            iter_valid += 1
            gc.collect()
            cover = cover.to(device)
            N, _, H, W = cover.size()
            # sampled from the discrete uniform distribution over 0 to 2
            payload = torch.zeros((N, flags['data_depth'], H, W),
                                  device=device).random_(0, 2)
            generated = encoder.forward(cover, payload)
            decoded = decoder.forward(generated)

            encoder_mse = mse_loss(generated, cover)
            decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
            decoder_acc = (decoded >= 0.0).eq(
                payload >= 0.5).sum().float() / payload.numel()
            generated_score = torch.mean(critic.forward(generated))
            cover_score = torch.mean(critic.forward(cover))
            ssim_ = ssim(cover, generated)
            psnr_ = 10 * torch.log10(4 / encoder_mse)
            bpp_ = flags['data_depth'] * (2 * decoder_acc.item() - 1)

            metrics['val.encoder_mse'].append(encoder_mse.item())
            metrics['val.decoder_loss'].append(decoder_loss.item())
            metrics['val.decoder_acc'].append(decoder_acc.item())
            metrics['val.cover_score'].append(cover_score.item())
            metrics['val.generated_score'].append(generated_score.item())
            metrics['val.ssim'].append(ssim_.item())
            metrics['val.psnr'].append(psnr_.item())
            metrics['val.bpp'].append(bpp_)

        xm.master_print('encoder_mse: %.3f - decoder_loss: %.3f - decoder_acc: %.3f - '
                        'cover_score: %.3f - generated_score: %.3f - ssim: %.3f - '
                        'psnr: %.3f - bpp: %.3f'
                        % (encoder_mse.item(), decoder_loss.item(), decoder_acc.item(),
                           cover_score.item(), generated_score.item(), ssim_.item(),
                           psnr_.item(), bpp_))
    # Save only once: same rendezvous pattern as the directory creation above
    if not xm.is_master_ordinal():
        xm.rendezvous('save_only_once')
    save_model(encoder, decoder, critic, en_de_optimizer, cr_optimizer,
               metrics, flags['num_epochs'], flags)
    if xm.is_master_ordinal():
        xm.rendezvous('save_only_once')
flags['METRIC_FIELDS'] = [
    'val.encoder_mse',
    'val.decoder_loss',
    'val.decoder_acc',
    'val.cover_score',
    'val.generated_score',
    'val.ssim',
    'val.psnr',
    'val.bpp',
    'train.encoder_mse',
    'train.decoder_loss',
    'train.decoder_acc',
    'train.cover_score',
    'train.generated_score'
]
if __name__ == '__main__':
    os.environ['PJRT_DEVICE'] = 'TPU'  # added per the PJRT doc
    xmp.spawn(map_fn, args=(flags,), nprocs=1, start_method='fork')
Also, here is the Colab notebook if you want to investigate further. You have edit access as well.
PJRT doesn't work on Colab. Were you able to get the above code to run on a TPU VM? I would recommend two things: first, remove xmp.spawn and call map_fn directly to see if the above code runs on a single core; second, check out our troubleshooting doc. You can find more docs under https://pytorch.org/xla/release/2.0/index.html#
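For example, a minimal single-core debug run might look like this (a sketch; it assumes flags is already defined in your notebook):

import os
os.environ['PJRT_DEVICE'] = 'TPU'

# Call map_fn directly instead of going through xmp.spawn, so any failure
# surfaces in the main process with a full traceback.
map_fn(0, flags)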
You mentioned that there are some subtle differences in the code. Can you point them out, and would addressing them solve the issue?
I'm fairly new to this and have little to no experience. I had a notebook running PyTorch that I wanted to run on a Google Cloud TPU VM. Machine specs:
I wanted to use all 8 cores I have for training a neural network. When the interpreter reaches the lines where I move the models to be trained to the device, it does nothing and prints what appears to be a warning:
https://symbolize.stripped_domain/r/?trace=7fcddb59105e,7fcfbafdd08f,.
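For reference, the step that stalls reduces to just moving a module to the XLA device; a minimal stand-in, where nn.Linear is a hypothetical placeholder for the real encoder/decoder/critic:

import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(10, 10)  # placeholder for the actual models
model = model.to(device)   # this .to(device) step is where it hangs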
The code is in the snippet above. So please, if you see anything wrong with it or have any suggestions, I'd really appreciate the help getting this to work.