pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

xm.save() hangs inside epoch loop #2712

Closed: mobassir94 closed this issue 3 years ago

mobassir94 commented 3 years ago

I am trying to save my trained model's weight file based on the maximum validation accuracy, using Kaggle kernels with a TPU v3-8. Here is the code I tried: https://pastebin.com/Uyz5iBFF

Please check the train_model() function. When I try to train the model, it hangs here:

xm.save(model.state_dict(), os.path.join(best_file))

But if I remove that line, everything works fine. You can see I have also used this line:

xm.save(model.state_dict(), os.path.join(f'{kernel_type}_final_fold{fold}.pth'))

to save each fold's model, and that one works fine. But saving the best-epoch model based on the best validation accuracy, inside the epoch loop, does not work. Where am I making a mistake? I am working in a private Kaggle notebook, so I only shared the minimal code through Pastebin. Sorry if it's not enough to understand the bug; if so, I will try to share the notebook through Colab.

taylanbil commented 3 years ago

The root cause is that xm.save has a rendezvous in it. All cores must sync at this rendezvous; otherwise the processes will hang.

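So, here's why this happens for you: each core runs validation on its own shard of the validation set (through valid_sampler), so each core computes a different acc. The acc > acc_max check then passes on some cores and fails on others, so only some cores call xm.save and wait in its rendezvous for cores that never arrive, causing a hang.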

To resolve this, there are two easy solutions:

  1. mesh_reduce the accuracy values before setting acc_max (as in the linked example), so acc and acc_max stay in sync across cores.
  2. Do not shard the validation dataset: set valid_sampler to None, so every core does exactly the same work, guaranteeing that acc and acc_max stay the same on every core.
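The hang and fix 1 can be reproduced without a TPU. The sketch below is a pure-Python simulation, not torch_xla code: `threading.Barrier` stands in for the rendezvous inside `xm.save`, threads stand in for TPU cores, and a plain mean stands in for `xm.mesh_reduce(..., np.mean)`. `NUM_CORES`, `run_epoch`, and the accuracy values are hypothetical.

```python
import threading

NUM_CORES = 4  # hypothetical; a TPU v3-8 actually runs 8 processes


def run_epoch(per_core_acc, reduce_first, acc_max=0.5, timeout=0.5):
    """Simulate 'if acc > acc_max: save()' on every core.

    The barrier plays the role of the rendezvous inside xm.save: a core
    that calls save() blocks until all NUM_CORES cores have arrived.
    Returns True if no core got stuck waiting (i.e. no hang).
    """
    barrier = threading.Barrier(NUM_CORES, timeout=timeout)
    reached = [True] * NUM_CORES
    # Stand-in for xm.mesh_reduce('test_accuracy', acc, np.mean):
    # every core gets back the same averaged value.
    mean_acc = sum(per_core_acc) / len(per_core_acc)

    def core(rank):
        acc = mean_acc if reduce_first else per_core_acc[rank]
        if acc > acc_max:  # decision made independently on each core
            try:
                barrier.wait()  # the "rendezvous" inside save()
            except threading.BrokenBarrierError:
                reached[rank] = False  # on real hardware: a hang

    threads = [threading.Thread(target=core, args=(r,)) for r in range(NUM_CORES)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(reached)


# Each core validated a different shard, so the accuracies differ.
accs = [0.60, 0.40, 0.70, 0.45]
print(run_epoch(accs, reduce_first=False))  # False: only 2 of 4 cores reach save()
print(run_epoch(accs, reduce_first=True))   # True: all cores make the same decision
```

Averaging before the comparison makes the branch identical on every core, so either all cores reach the rendezvous or none do.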

mobassir94 commented 3 years ago

@taylanbil Thank you. If I set valid_sampler to None, will it hamper training performance? Will it slow down training?

taylanbil commented 3 years ago

You would be doing duplicated, unnecessary work, yes. But it is only in validation: validation datasets are typically small, and forward-only passes are fast. So the slowdown should be negligible, but I don't know your specifics, so I can't comment further.
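To see why `valid_sampler = None` keeps `acc` identical across cores, and what the duplicated work costs, here is a simplified, non-shuffling stand-in for a distributed sampler that splits indices round-robin (`indices[rank::world_size]`). `shard`, `accuracy`, and the toy correctness list are hypothetical; the real `torch.utils.data.DistributedSampler` also shuffles by default and pads the index list so every replica gets the same number of samples.

```python
WORLD_SIZE = 8  # processes on a TPU v3-8

# Hypothetical per-sample correctness on a 16-sample validation set
# (1 = prediction matched the target).
correct = [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0]


def shard(indices, rank, world_size):
    """Round-robin split, like a non-shuffling distributed sampler
    when len(indices) is divisible by world_size."""
    return indices[rank::world_size]


def accuracy(idx):
    return 100.0 * sum(correct[i] for i in idx) / len(idx)


all_idx = list(range(len(correct)))

# With a sampler: each core scores 2 samples, so the accuracies disagree.
sharded = [accuracy(shard(all_idx, r, WORLD_SIZE)) for r in range(WORLD_SIZE)]

# Without a sampler: every core scores all 16 samples, so the values are
# identical, at the cost of 8x the forward passes per core.
unsharded = [accuracy(all_idx) for _ in range(WORLD_SIZE)]

print(sharded)    # [100.0, 50.0, 50.0, 100.0, 0.0, 50.0, 100.0, 50.0]
print(unsharded)  # [62.5, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5, 62.5]
```

With equal-sized shards, the mean of the per-core accuracies (62.5 here) equals the global accuracy, which is why averaging across cores (option 1) gives the same number without the duplicated work.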

mobassir94 commented 3 years ago

I've tried mesh_reduce, and now I get this error:

Exception in device=TPU:2: Aborted: Session ca406e137d186162 is not found.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "<timed exec>", line 4, in _mp_fn
  File "<ipython-input-17-55fa67936603>", line 85, in train_model
    val_loss, acc = val_epoch(model,device,para_loader.per_device_loader(device))
  File "<ipython-input-16-e72e94103afd>", line 43, in val_epoch
    val_loss.append(loss.detach().cpu().numpy())
RuntimeError: Aborted: Session ca406e137d186162 is not found.
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
<timed exec> in <module>

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
    392         join=join,
    393         daemon=daemon,
--> 394         start_method=start_method)
    395 
    396 

/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    203 
    204     # Loop on join until it returns True or raises an exception.
--> 205     while not context.join():
    206         pass
    207 

/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    150                     error_pid=failed_process.pid,
    151                     exit_code=exitcode,
--> 152                     signal_name=name
    153                 )
    154             else:

ProcessExitedException: process 6 terminated with signal SIGABRT

Is it because of the large batch size or the number of workers? The error appears after 10-12 minutes of training.

taylanbil commented 3 years ago

Can you show a diff of how you added it?

taylanbil commented 3 years ago

Is acc a tensor when you mesh_reduce it? Can you try with acc = acc.cpu() instead?

mobassir94 commented 3 years ago

The error appears even without mesh_reduce and without the acc > acc_max code in the epoch loop, so it is probably related to something else. Here is the full code: https://colab.research.google.com/drive/186HNlc1QvmfAC0RVRSxyDEkqfPn_geuF?usp=sharing

taylanbil commented 3 years ago

If you revert, do you get the hang or this error?

mobassir94 commented 3 years ago

Now I am getting this error even after reverting.

taylanbil commented 3 years ago

Yeah, that seems unrelated to mesh_reduce. Is your TPU still available? Maybe it was preempted?

mobassir94 commented 3 years ago

Yes, the TPU is available. I have another PyTorch XLA notebook that works fine; I ran it a few minutes ago and everything was fine, which means the TPU is available. I have a gut feeling that my val_epoch() function has some issue.

On GPU, I had working code like this:

def val_epoch(loader, get_output=False):

    model.eval()
    val_loss = []
    LOGITS = []
    PREDS = []
    TARGETS = []

    with torch.no_grad():
        for (data, target) in tqdm(loader):
            data, target = data.to(device), target.to(device)
            logits = model(data)

            loss = criterion(logits, target)

            pred = logits.softmax(1).argmax(1).detach()
            LOGITS.append(logits)
            PREDS.append(pred)
            TARGETS.append(target)

            val_loss.append(loss.detach().cpu().numpy())
        val_loss = np.mean(val_loss)

    LOGITS = torch.cat(LOGITS).cpu().numpy()
    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()
    acc = (PREDS == TARGETS).mean() * 100.

    if get_output:
        return LOGITS
    else:
        return val_loss, acc

On TPU, I was trying this instead:

def val_epoch(model,device,loader, get_output=False):

    model.eval()
    val_loss = []
    LOGITS = []
    PREDS = []
    TARGETS = []

    #with torch.no_grad():
    for (data, target) in tqdm(loader):
        data, target = data.to(device), target.to(device)
        logits = model(data)

        loss = criterion(logits, target)

        pred = logits.softmax(1).argmax(1).detach() 

        LOGITS.append(logits)
        #xm.master_print(target,logits,pred)
        PREDS.append(pred)
        TARGETS.append(target)

        val_loss.append(loss.detach().cpu().numpy())
    val_loss = np.mean(val_loss)

    LOGITS = torch.cat(LOGITS).cpu().detach().numpy()
    PREDS = torch.cat(PREDS).cpu().detach().numpy()
    TARGETS = torch.cat(TARGETS).detach().cpu().numpy()

    acc1 = (PREDS == TARGETS).mean() * 100.

    acc = xm.mesh_reduce('test_accuracy', acc1, np.mean)

    acc = acc.cpu()
    xm.master_print(acc,acc1,PREDS,TARGETS)

    if get_output:
        return LOGITS
    else:
        return val_loss, acc

mobassir94 commented 3 years ago

With this code, everything works fine now:

def val_epoch(model,device,loader, get_output=False):

    model.eval()
    val_loss = []
    LOGITS = []
    PREDS = []
    TARGETS = []

    with torch.no_grad():
        for (data, target) in tqdm(loader):
            data, target = data.to(device), target.to(device)
            logits = model(data)

            loss = criterion(logits, target)

            pred = logits.softmax(1).argmax(1).detach() 
            #xm.master_print("pred",pred)
            LOGITS.append(logits)
            PREDS.append(pred)
            TARGETS.append(target)

            val_loss.append(loss.detach().cpu().numpy())
        val_loss = np.mean(val_loss)

    LOGITS = torch.cat(LOGITS).cpu().numpy()
    xm.master_print("LOGITS",LOGITS)
    PREDS = torch.cat(PREDS).cpu().numpy()
    xm.master_print("preds",PREDS)
    TARGETS = torch.cat(TARGETS).cpu().numpy()
    xm.master_print("TARGETS",TARGETS)
    acc1 = (PREDS == TARGETS).mean() * 100.
    xm.master_print("acc",acc1)
    acc = xm.mesh_reduce('test_accuracy', acc1, np.mean)
    #acc = acc.cpu()
    xm.master_print("final acc",acc)

    if get_output:
        return LOGITS
    else:
        return val_loss, acc

Thank you, @taylanbil!
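In the working version, `acc = acc.cpu()` is commented out. A likely reason, shown with NumPy only: `(PREDS == TARGETS).mean() * 100.` on NumPy arrays yields a NumPy scalar, not a tensor, and reducing a list of those with `np.mean` also yields a NumPy scalar, so there is no `.cpu()` method to call. `fake_mesh_reduce` is a hypothetical stand-in that mimics only the final `reduce_fn(list_of_values)` step of `xm.mesh_reduce`; the per-core values are made up.

```python
import numpy as np


def fake_mesh_reduce(per_core_values, reduce_fn):
    """Hypothetical stand-in: xm.mesh_reduce gathers one value per core
    and returns reduce_fn(list_of_values)."""
    return reduce_fn(list(per_core_values))


preds = np.array([3, 4, 1, 3])
targets = np.array([3, 4, 3, 3])
acc1 = (preds == targets).mean() * 100.   # same expression as in val_epoch

# Pretend a second core reported 50.0.
acc = fake_mesh_reduce([acc1, 50.0], np.mean)

print(type(acc1).__name__, acc1)  # float64 75.0
print(type(acc).__name__, acc)    # float64 62.5
print(hasattr(acc, "cpu"))        # False: acc.cpu() would raise AttributeError
```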

mobassir94 commented 3 years ago

Sorry, one last question, @taylanbil:

After trying the code above, everything works and the weights are getting saved, but the validation accuracy always remains 0.0. Can you tell me why? It works fine on GPU (you can check the GPU baseline here: https://www.kaggle.com/haqishen/baseline-modified-from-previous-competition).

taylanbil commented 3 years ago

That's harder to diagnose. Are you observing any patterns in your printed targets, preds, and accuracies? Are all accuracies zero before mesh reducing? Are all preds the same? Etc.
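The checks suggested above can be bundled into a small helper run on each core before `mesh_reduce`. This is a hypothetical sketch (`check_predictions` and its `num_classes` parameter are not from the thread, and 5 classes is an assumption); it uses only NumPy and flags predictions outside the valid class range or collapsed to a single value.

```python
import numpy as np


def check_predictions(preds, targets, num_classes):
    """Return a list of human-readable problems found in one core's outputs."""
    preds = np.asarray(preds)
    targets = np.asarray(targets)
    problems = []
    if preds.shape != targets.shape:
        problems.append(f"shape mismatch: {preds.shape} vs {targets.shape}")
    # A class index must lie in [0, num_classes).
    out_of_range = (preds < 0) | (preds >= num_classes)
    if out_of_range.any():
        bad = np.unique(preds[out_of_range])
        problems.append(f"predictions outside [0, {num_classes}): {bad.tolist()}")
    # Every prediction being the same value usually means something upstream broke.
    if preds.size > 1 and np.unique(preds).size == 1:
        problems.append(f"all predictions identical: {preds[0]}")
    return problems


# Example with values like the ones printed in this thread:
targets = [3, 4, 3, 3, 3, 3, 3, 1]
preds = [9223372036854775807] * 8
for problem in check_predictions(preds, targets, num_classes=5):
    print(problem)
```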

mobassir94 commented 3 years ago

After doing this:

PREDS = torch.cat(PREDS).cpu().numpy()
TARGETS = torch.cat(TARGETS).cpu().numpy()
xm.master_print("Targets",TARGETS)
xm.master_print("PREDS",PREDS)
acc1 = (PREDS == TARGETS).mean() * 100.
xm.master_print("before Mesh ",acc1)
acc = xm.mesh_reduce('test_accuracy', acc1, np.mean)
xm.master_print("After mesh ",acc)

I get output like this:

Targets [3 4 3 3 3 3 3 1]
PREDS [9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807]
before Mesh 0.0
After mesh 0.0

Just before I closed this issue, I was not getting 0.0 with the exact same code; I'm not sure what happened.
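The repeated prediction value `9223372036854775807` is not arbitrary: it is 2**63 - 1, the largest value a signed 64-bit integer can hold, so it cannot be a valid class index. The snippet below only confirms that identity with the standard library and NumPy; it does not by itself establish why argmax produced this value here.

```python
import numpy as np

SUSPICIOUS = 9223372036854775807  # value printed for every prediction above

print(SUSPICIOUS == 2**63 - 1)               # True
print(SUSPICIOUS == np.iinfo(np.int64).max)  # True
```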

taylanbil commented 3 years ago

Wow, those predictions look way off. Maybe something is wrong with softmax? Are you using the latest torch_xla nightly?

mobassir94 commented 3 years ago

Yes, I am using the latest torch_xla nightly. I am not sure why it worked just before I closed this issue; after restarting the notebook, it always gives 0.0. I don't understand where I am making a mistake.

taylanbil commented 3 years ago

Can you print the logits before applying softmax? There may be an issue with it.

You could also try version 1.7, maybe?

mobassir94 commented 3 years ago

I tried version 1.7 and printed the logits before applying softmax, and I get:

LOGITS tensor([[-0.5781, -1.0312, -0.8438, -0.5859, 0.9023],
        [-0.5898, -1.0234, -0.7773, -0.7266, 0.8164],
        [-0.5938, -1.0781, -0.7930, -0.6680, 0.9492],
        [-0.4766, -1.0703, -0.7344, -0.6680, 0.8711],
        [-0.6016, -1.0547, -0.7812, -0.6562, 0.9492],
        [-0.6758, -1.1172, -0.7227, -0.7266, 1.0156],
        [-0.7188, -1.0781, -0.8164, -0.6836, 0.9453],
        [-0.6562, -0.9453, -0.8945, -0.5820, 0.8203]], device='xla:1', grad_fn=)
PREDSjn tensor([9223372036854775807, 9223372036854775807, 9223372036854775807, 9223372036854775807, 9223372036854775807, 9223372036854775807, 9223372036854775807, 9223372036854775807], device='xla:1')

Targets [2 3 3 2 4 3 1 4]
PREDS [9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807 9223372036854775807]
before Mesh 0.0
After mesh 0.0
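As a cross-check on the logits printed above: softmax is strictly increasing in each logit within a row, so `softmax(1).argmax(1)` should pick the same column as `argmax(1)` on the raw logits. The pure-Python sketch below copies the printed values (and assumes the `Targets` line belongs to the same batch); `softmax` and `argmax` are local helpers, not torch functions. It shows what the predictions and accuracy should have been for this batch.

```python
import math

logits = [
    [-0.5781, -1.0312, -0.8438, -0.5859, 0.9023],
    [-0.5898, -1.0234, -0.7773, -0.7266, 0.8164],
    [-0.5938, -1.0781, -0.7930, -0.6680, 0.9492],
    [-0.4766, -1.0703, -0.7344, -0.6680, 0.8711],
    [-0.6016, -1.0547, -0.7812, -0.6562, 0.9492],
    [-0.6758, -1.1172, -0.7227, -0.7266, 1.0156],
    [-0.7188, -1.0781, -0.8164, -0.6836, 0.9453],
    [-0.6562, -0.9453, -0.8945, -0.5820, 0.8203],
]
targets = [2, 3, 3, 2, 4, 3, 1, 4]  # from the same printout


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


preds_raw = [argmax(r) for r in logits]
preds_soft = [argmax(softmax(r)) for r in logits]
acc = 100.0 * sum(p == t for p, t in zip(preds_soft, targets)) / len(targets)

print(preds_raw)                # [4, 4, 4, 4, 4, 4, 4, 4]
print(preds_raw == preds_soft)  # True
print(acc)                      # 25.0
```

So for this batch the expected predictions are valid class indices with 25% accuracy, not 0.0, and dropping the softmax before argmax does not change them.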

taylanbil commented 3 years ago

There must be a problem with softmax, then. Although I don't understand how it was ever working for you.

mobassir94 commented 3 years ago

Using this exact same data, I prepared another PyTorch XLA notebook, where I have something like this instead:

def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    model.eval()

    t = time.time()
    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        image_preds = model(imgs)   #output = model(input)
        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]

        loss = loss_fn(image_preds, image_labels)

        loss_sum += loss.item()*image_labels.shape[0]
        sample_num += image_labels.shape[0]  

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    acc = (image_preds_all==image_targets_all).mean()
    #LOGGER.debug('validation multi-class accuracy = {:.4f}'.format(acc))
    accuracy = xm.mesh_reduce('test_accuracy', acc, np.mean)
    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()
    return accuracy

It works fine there; I don't get 0.0 with it.
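One detail that differs from the first `val_epoch`: `valid_one_epoch` averages the loss per sample (`loss_sum += loss.item()*image_labels.shape[0]`, then divides by `sample_num`), whereas `np.mean(val_loss)` averages per batch. The two only disagree when batches have different sizes, typically a smaller final batch. The numbers below are hypothetical toy values.

```python
batch_losses = [1.0, 1.0, 2.0]  # mean loss of each batch
batch_sizes = [8, 8, 4]         # the last batch is smaller

# What np.mean(val_loss) computes: every batch counts equally.
per_batch_mean = sum(batch_losses) / len(batch_losses)

# What valid_one_epoch computes: every sample counts equally.
per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(per_batch_mean)   # 1.3333333333333333
print(per_sample_mean)  # 1.2
```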

taylanbil commented 3 years ago

Oh, so criterion was already applying softmax, and you were applying a second one that threw off the results. Do I understand that right?

Also, if you call mark_step before transferring things to the CPU, you'll get faster performance once compilations stabilize. Here:

        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]

Use xm.mark_step() just before those lines. Otherwise, the graphs are executed prematurely to transfer the data off the device right there, and then executed again (unnecessarily) once the step is marked.
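A sketch of where the suggested `xm.mark_step()` call goes in the `valid_one_epoch` loop. To keep it runnable without a TPU, the step and host-transfer functions are passed in as callables (`collect_outputs`, `mark_step`, and `to_host` are hypothetical names); on a TPU you would pass `xm.mark_step` and something like `lambda t: torch.argmax(t, 1).detach().cpu().numpy()`. The toy run only records the order of calls.

```python
def collect_outputs(batches, model, mark_step, to_host):
    """Run the forward pass on every batch, cutting the graph with
    mark_step() *before* any tensor is copied to the host."""
    preds_all, targets_all = [], []
    for imgs, labels in batches:
        preds = model(imgs)
        mark_step()                        # on TPU: xm.mark_step()
        preds_all.append(to_host(preds))   # e.g. .detach().cpu().numpy()
        targets_all.append(to_host(labels))
    return preds_all, targets_all


# Toy run with stand-in callables that record the order of events.
events = []


def fake_model(x):
    events.append("forward")
    return x


def fake_mark_step():
    events.append("mark_step")


def fake_to_host(t):
    events.append("to_host")
    return t


collect_outputs([(1, 0), (2, 1)], fake_model, fake_mark_step, fake_to_host)
print(events)
# ['forward', 'mark_step', 'to_host', 'to_host', 'forward', 'mark_step', 'to_host', 'to_host']
```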

mobassir94 commented 3 years ago

I am using this notebook: https://www.kaggle.com/haqishen/baseline-modified-from-previous-competition

Nothing extraordinary; I am just trying to convert that GPU baseline kernel to TPU.

After trying the valid_one_epoch() function from my last comment, the code worked. There I use argmax(), whereas in the GPU baseline kernel Qishen Ha used softmax(1).argmax(1).

You can see that within 10 epochs he got 0.85+ validation accuracy, but with the same baseline on TPU I get 0.62 even after 20 epochs. I am not sure why there is such a huge performance difference. I have spent a lot of time on this TPU baseline, and now I am getting a very bad result. I tried 440x440x3 images, efficientnet_b2, and a batch size of 8x8; Qishen Ha used 256x256x3 images, effnetb0, and a batch size of 64 to get 0.85+ accuracy. I have shared my full code with you in a comment above. If you compare it with this GPU baseline: https://www.kaggle.com/haqishen/baseline-modified-from-previous-competition

I believe you won't find any mistake. I am simply trying to replicate a similar result on TPU, so I didn't change the baseline much and kept almost everything as in the GPU baseline (I only updated the valid_one_epoch function as mentioned above). Now I am getting a terrible result: validation accuracy stays around 60% and never improves, whereas Qishen Ha's GPU baseline got 0.85+ and kept improving.

taylanbil commented 3 years ago

Took a quick look: the Kaggle kernel uses a pretrained model, as seen in cells 7 and 8, while the Colab link from above doesn't use pretrained weights. Generally, this type of discrepancy leads to big differences in downstream tasks.

mobassir94 commented 3 years ago

My bad, what a terrible mistake; I'm sorry. I didn't notice that I was not using a pretrained model. Now I understand why training was so fast: I expected one epoch to take ~15 minutes, and it took 3-4 minutes :D Thank you a ton ♥

mobassir94 commented 3 years ago

Hi @taylanbil, the problem of saving weights inside the epoch loop using xm.save() still persists. Please check this: https://github.com/pytorch/xla/issues/2724