pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Training code hangs at various locations #945

Closed: ibeltagy closed this issue 5 years ago

ibeltagy commented 5 years ago

The training loop hangs at various locations. The hangs started after some changes to how the dataset is read (though this could be unrelated). It looks a lot like issue https://github.com/pytorch/xla/issues/821, but I am using the nightly build, which is supposed to include the fix.

It hangs at the locations marked below. It also hangs at the start of the second epoch if model.save() and return average_loss.item() are removed.

def train_loop_fn(model, loader, device, context):
    for x, (data, target) in loader:

        # run model, compute loss ... etc

        if device == first_device:
            model.save()  # << === hangs here

    return average_loss.item()  # << === and hangs here as well

Below is the gdb backtrace for the hang at model.save():

Thread 1 "python" received signal SIGINT, Interrupt.
0x00007ffff7bcb556 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7ffe4c000c10) at ../sysdeps/unix/sysv/linux/futex-internal.h:205
205 in ../sysdeps/unix/sysv/linux/futex-internal.h
#0  0x00007ffff7bcb556 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7ffe4c000c10) at ../sysdeps/unix/sysv/linux/futex-internal.h:205
#1  do_futex_wait (sem=sem@entry=0x7ffe4c000c10, abstime=0x0) at sem_waitcommon.c:111
#2  0x00007ffff7bcb604 in __new_sem_wait_slow (sem=0x7ffe4c000c10, abstime=0x0) at sem_waitcommon.c:181
#3  0x000055555563ff76 in PyThread_acquire_lock_timed () at /tmp/build/80754af9/python_1546130271559/work/Python/thread_pthread.h:386
#4  0x00005555556d21ac in acquire_timed (timeout=-1000000000, lock=0x7ffe4c000c10) at /tmp/build/80754af9/python_1546130271559/work/Modules/_threadmodule.c:68
#5  lock_PyThread_acquire_lock () at /tmp/build/80754af9/python_1546130271559/work/Modules/_threadmodule.c:151
#6  0x0000555555665744 in _PyCFunction_FastCallDict () at /tmp/build/80754af9/python_1546130271559/work/Objects/methodobject.c:231
#7  0x00005555556ec42c in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4851
#8  0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#9  0x00005555556e58e4 in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#10 0x00005555556e6771 in fast_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4992
#11 0x00005555556ec505 in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4872
#12 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#13 0x00005555556e58e4 in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#14 0x00005555556e6771 in fast_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4992
#15 0x00005555556ec505 in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4872
#16 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#17 0x00005555556e6bab in _PyFunction_FastCall (globals=<optimized out>, nargs=3, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4933
#18 _PyFunction_FastCallDict () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:5035
#19 0x0000555555665b0f in _PyObject_FastCallDict () at /tmp/build/80754af9/python_1546130271559/work/Objects/abstract.c:2310
#20 0x000055555566a6a3 in _PyObject_Call_Prepend () at /tmp/build/80754af9/python_1546130271559/work/Objects/abstract.c:2373
#21 0x000055555566554e in PyObject_Call () at /tmp/build/80754af9/python_1546130271559/work/Objects/abstract.c:2261
#22 0x00005555556bfa91 in slot_tp_call () at /tmp/build/80754af9/python_1546130271559/work/Objects/typeobject.c:6207
#23 0x000055555566592b in _PyObject_FastCallDict () at /tmp/build/80754af9/python_1546130271559/work/Objects/abstract.c:2331
#24 0x00005555556ec57e in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4875
#25 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#26 0x00005555556e5bfe in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#27 0x00005555556e6771 in fast_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4992
#28 0x00005555556ec505 in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4872
#29 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#30 0x00005555556e7289 in _PyEval_EvalCodeWithName (qualname=0x0, name=<optimized out>, closure=0x0, kwdefs=0x0, defcount=0, defs=0x0, kwstep=2, kwcount=<optimized out>, kwargs=0x0, kwnames=0x0, argcount=0, args=0x0, locals=0x7ffff7f64090, globals=0x7ffff7f64090, _co=0x7ffff6e86150) at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#31 PyEval_EvalCodeEx () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4187
#32 0x00005555556e801c in PyEval_EvalCode (co=co@entry=0x7ffff6e86150, globals=globals@entry=0x7ffff7f64090, locals=locals@entry=0x7ffff7f64090) at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:731
#33 0x000055555570ed97 in builtin_exec_impl.isra.11 (locals=0x7ffff7f64090, globals=0x7ffff7f64090, source=0x7ffff6e86150) at /tmp/build/80754af9/python_1546130271559/work/Python/bltinmodule.c:983
#34 builtin_exec () at /tmp/build/80754af9/python_1546130271559/work/Python/clinic/bltinmodule.c.h:283
#35 0x0000555555665681 in _PyCFunction_FastCallDict () at /tmp/build/80754af9/python_1546130271559/work/Objects/methodobject.c:234
#36 0x00005555556ec42c in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4851
#37 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#38 0x00005555556e58e4 in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#39 0x00005555556e6771 in fast_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4992
#40 0x00005555556ec505 in call_function () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4872
#41 0x000055555571138a in _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:3335
#42 0x00005555556e7289 in _PyEval_EvalCodeWithName (qualname=0x0, name=<optimized out>, closure=0x0, kwdefs=0x0, defcount=1, defs=0x7ffff6ea63e0, kwstep=2, kwcount=<optimized out>, kwargs=0x0, kwnames=0x0, argcount=<optimized out>, args=0x5555558e1490, locals=0x0, globals=<optimized out>, _co=0x7ffff6e8fdb0) at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4166
#43 PyEval_EvalCodeEx () at /tmp/build/80754af9/python_1546130271559/work/Python/ceval.c:4187
#44 0x00005555556e8109 in function_call () at /tmp/build/80754af9/python_1546130271559/work/Objects/funcobject.c:604
#45 0x000055555566554e in PyObject_Call () at /tmp/build/80754af9/python_1546130271559/work/Objects/abstract.c:2261
#46 0x0000555555763b07 in RunModule () at /tmp/build/80754af9/python_1546130271559/work/Modules/main.c:215
#47 0x000055555576e1b6 in Py_Main () at /tmp/build/80754af9/python_1546130271559/work/Modules/main.c:752
#48 0x000055555563702e in main () at /tmp/build/80754af9/python_1546130271559/work/Programs/python.c:69
#49 0x00007ffff783d2e1 in __libc_start_main (main=0x555555636f40 <main>, argc=20, argv=0x7fffffffe968, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffe958) at ../csu/libc-start.c:291
#50 0x0000555555717e0e in _start () at ../sysdeps/x86_64/elf/start.S:103

ibeltagy commented 5 years ago

I used thread apply all bt, so I'm not sure why the log shows only one thread.

dlibenzi commented 5 years ago

That looks like a Python thread. A thread apply all bt would be more useful. If you remove the model.save(), does it still hang? What about running on a single core?

ibeltagy commented 5 years ago

Updates about this:

dlibenzi commented 5 years ago

A thread apply all bt should dump every thread (Python, C++, ...).

The data dependence suggests to me that it is not actually hanging, but recompiling due to dynamic shapes.

Printing torch_xla._XLAC._xla_metrics_report() at every step would detect that.
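A minimal sketch of how the per-step metrics report could flag recompilation even if the run eventually hangs: the reports printed before the hang would already show the compile count climbing from step to step. It assumes the report text contains a "Metric: CompileTime" section with an indented "TotalSamples: <n>" line (check this against the report your build actually prints). compile_count and RecompileWatch are hypothetical helpers, not torch_xla APIs.

```python
import re

# Assumed layout: "Metric: CompileTime" followed by indented "key: value" lines,
# one of which is "TotalSamples: <int>". The lazy group only crosses indented
# lines, so it cannot run into the next "Metric: ..." section.
_COMPILE_RE = re.compile(
    r"Metric: CompileTime[ \t]*\n(?:[ \t]+.*\n)*?[ \t]+TotalSamples: (\d+)"
)


def compile_count(report):
    """Return the CompileTime sample count from a metrics report (0 if absent)."""
    match = _COMPILE_RE.search(report)
    return int(match.group(1)) if match else 0


class RecompileWatch:
    """Flags steps, after a warm-up period, in which the compile count increased."""

    def __init__(self, warmup_steps=2):
        self.warmup_steps = warmup_steps  # early steps legitimately compile new graphs
        self.prev_count = None
        self.step = 0

    def update(self, report):
        """Feed one report per step; returns True if this step triggered a compile."""
        count = compile_count(report)
        grew = (
            self.prev_count is not None
            and count > self.prev_count
            and self.step >= self.warmup_steps
        )
        self.prev_count = count
        self.step += 1
        return grew
```

In the training loop one would call watch.update(torch_xla._XLAC._xla_metrics_report()) once per step and log a warning when it returns True. If the count keeps growing step after step, the "hang" is more likely repeated compilation than a deadlock.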

ibeltagy commented 5 years ago

Hmm, I doubt it is a recompilation issue because there are no dynamic shapes in the code. Also, slow recompilation might explain a slow loss.item(), but it wouldn't explain a slow model.save(), right?

> A per step print(torch_xla._XLAC._xla_metrics_report())

How would that work? Even if I call it every step, it won't be called once the code hangs, so I won't be able to see the report.

dlibenzi commented 5 years ago

The fact that it depends on the data, as you mentioned, suggests a recompilation issue. How long does it hang (i.e., how long do you wait before hitting ^C)?

Are you building from source or using nightly builds (and nightly TPU VM)?

Do you have links to your code?

ibeltagy commented 5 years ago

I hit ^C after a few hours. I am using the Docker image of the nightly build.

I found out why I was getting a backtrace of only one thread. Here is the backtrace of all threads: https://gist.github.com/ibeltagy/9d0ec1bd1c71d74f2e869260b3187fd6#file-bt-txt

dlibenzi commented 5 years ago

Thanks! If you run the stack traces through scripts/stack_trace_parse.py, it merges identical traces and provides a more useful view:

https://gist.github.com/dlibenzi/48a52a6ad0d0384afb4c45df0a045af6

It seems to be hanging in Execute, but this is not a single-core run: there are 8 Execute calls in flight. To narrow this down, can we try a single core? Also, it would be really helpful if you could provide GitHub links to your model's code (or at least the main loop).

ibeltagy commented 5 years ago

I couldn't reproduce the hang with a single core.

The model code is just BERT, https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py

The training code is messy, so I copied the important parts here, https://gist.github.com/ibeltagy/2cec9bf7f5a429b12dba90717baa9635

This is the log from when it hangs at line 88: https://gist.github.com/ibeltagy/27e33876d89785c7dc4c4a67700379e9. You already have the log from when it hangs at line 63.

dlibenzi commented 5 years ago

To narrow it down further, can we try disabling model saving and using gradient_accumulation_steps=1 (and running with 8 cores)?

ibeltagy commented 5 years ago

The previous log (https://gist.github.com/dlibenzi/48a52a6ad0d0384afb4c45df0a045af6) was taken with gradient_accumulation_steps=1, 8 cores, and no model saving. It hangs at line 63.

ibeltagy commented 5 years ago

> can we try disabling model saving,

Also, I don't think it is related to saving per se. It feels more like a synchronization problem between threads, with some threads waiting for each other.

dlibenzi commented 5 years ago

Can you try printing the following, just before the return statement of the train loop function?

print(torch_xla._XLAC._get_xla_tensors_text([tr_loss, tr_segment_pred_loss, tr_masked_lm_loss]))

ibeltagy commented 5 years ago

Nothing surprising:

IR {
  %0 = f32[] xla::device_data(), device=TPU:1, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:1, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:1, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:4, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:4, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:4, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:7, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:7, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:7, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:3, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:3, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:3, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:2, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:2, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:2, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:0, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:0, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:0, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:6, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:6, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:6, ROOT=2
}

IR {
  %0 = f32[] xla::device_data(), device=TPU:5, ROOT=0
  %1 = f32[] xla::device_data(), device=TPU:5, ROOT=1
  %2 = f32[] xla::device_data(), device=TPU:5, ROOT=2
}

ibeltagy commented 5 years ago

I think the problem happens when the threads are not doing the same amount of work and therefore do not reach the return statement of the training loop (tpu_training_loop) at the same time. This happens when the dataset size is not divisible by the batch size: one thread is faster than the rest because its last batch is smaller than theirs.
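A small simulation of the uneven last batch described above. It assumes, hypothetically, that the loader hands batches to cores round-robin and keeps the final partial batch (drop_last=False); batch_sizes and per_core_batches are illustrative names, not torch_xla functions.

```python
def batch_sizes(num_examples, batch_size):
    """Sizes of consecutive batches; the last one may be smaller (drop_last=False)."""
    full, rest = divmod(num_examples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def per_core_batches(num_examples, batch_size, num_cores):
    """Batch sizes seen by each core, assuming round-robin assignment."""
    sizes = batch_sizes(num_examples, batch_size)
    return [sizes[core::num_cores] for core in range(num_cores)]
```

With 1000 examples, batch size 32 and 8 cores, every core runs 4 steps, but core 7's last batch has 8 examples instead of 32, so its final step has a different input shape from the other cores. With batch size 40, the 25 batches do not split evenly, so core 0 runs one more step than the others.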

ibeltagy commented 5 years ago

It seems that whenever it hangs, 7 cores usually reach the return statement of the training function, and the 8th core reaches it a bit later. Making sure that the dataset size is divisible by the batch size prevents this from happening.
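One way to apply this workaround is to trim the dataset before building the loader. The sketch below trims to a multiple of batch_size * num_cores, which is slightly stricter than divisibility by the batch size alone: under the same round-robin assumption, it also guarantees that every core gets the same number of batches. trim_to_multiple is a hypothetical helper.

```python
def trim_to_multiple(num_examples, batch_size, num_cores):
    """Largest count <= num_examples that gives every core the same full batches.

    Returns 0 if the dataset is smaller than one round (batch_size * num_cores).
    """
    round_size = batch_size * num_cores  # examples consumed when each core takes one batch
    return num_examples - num_examples % round_size
```

For example, 1000 examples with batch size 32 on 8 cores would be trimmed to 768 (3 full rounds of 256), skipping 232 examples per epoch; feeding the loader list(range(768)) of a shuffled index order lets a different subset be skipped each epoch.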

dlibenzi commented 5 years ago

I can see that happening if the number of batches returned by your data loader is exactly divisible by num_cores and the last batch has a different size. Can you try this?

https://github.com/pytorch/xla/pull/966

ibeltagy commented 5 years ago

OK, will try, but this still doesn't fix the root cause: threads shouldn't deadlock when one is slower than the others.

dlibenzi commented 5 years ago

> ok, will try, but this still doesn't solve the root cause that threads shouldn't deadlock when one is slower than the others

It should not be a matter of one core being slower. The PR addresses one core either getting a badly sized last batch, or not getting one at all (if drop_last=True).

I am assuming here that model saving is disabled.
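The distinction above, a slower core versus a core that issues a different number of collective steps, can be illustrated with a pure-Python analogy. This is not torch_xla code: a threading.Barrier stands in for the Cross Replica Sum that each training step issues, one thread plays one core, and the barrier timeout exists only so the demo returns instead of hanging forever as a real TPU run would. run_replicas is a hypothetical helper.

```python
import threading
import time


def run_replicas(steps_per_core, delays, timeout=1.0):
    """Simulate cores that join one collective (a Barrier) per training step.

    Returns True if every core finished, False if a collective timed out.
    """
    num_cores = len(steps_per_core)
    collective = threading.Barrier(num_cores, timeout=timeout)
    failed = threading.Event()

    def core(idx):
        try:
            for _ in range(steps_per_core[idx]):
                time.sleep(delays[idx])  # per-core compute time for this step
                collective.wait()        # every core must join, like a cross-replica sum
        except threading.BrokenBarrierError:
            failed.set()                 # a timeout breaks the barrier for all waiters

    threads = [threading.Thread(target=core, args=(i,)) for i in range(num_cores)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return not failed.is_set()
```

run_replicas([3, 3, 3, 3], [0, 0, 0, 0.1]) completes, because the slow core only makes the others wait. run_replicas([3, 3, 3, 2], [0, 0, 0, 0]) times out, because three cores wait in a third collective that the fourth core never joins; a missing last batch, or extra work issued on only one core, creates this kind of mismatch.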

ibeltagy commented 5 years ago

Yes, model saving is disabled, but if I enable it again, it will slow down one of the threads and trigger this deadlock again. That's why I say this PR doesn't fix the root cause of the problem.

dlibenzi commented 5 years ago

Why did I ask for model saving to be disabled? In replication mode, all the computations issued to the cores must be exactly the same (both in terms of operations and tensor shapes). Doing an if-device-is-0-do-this kind of thing will likely trigger an extra TPU computation on device 0 that does not match the ones running on the other cores. If this is done on the inference path (where we do not issue an xm.optimizer_step(), hence no Cross Replica Sum is in flight), no biggie, but if it is done during training, a hang is likely.

A better way to do this is to save from all cores, or to copy all weights to CPU on all cores and then save the CPU tensors from one core only.

ibeltagy commented 5 years ago

Interesting. Thanks for the clarification.

ibeltagy commented 5 years ago

I haven't tried your PR yet, but I implemented something similar yesterday, and it has been training for 8 hours straight with no hangs.