pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Error During Multi Core TPU Training #6048

Closed: mfatih7 closed this issue 9 months ago

mfatih7 commented 10 months ago

Hello

While trying to run my training loop on multiple cores of a TPU v2, I get the error below. Is this an XLA error, or is there something wrong in my script?

best regards

https://symbolize.stripped_domain/r/?trace=7f0116c2be94,7f025a79908f,7f0116af7993,7f0116b0d462,7f0116afbf1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7f01122e2000-7f0120b21e40 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 58456 (TID 59813) on cpu 3; stack trace: ***
PC: @     0x7f0116c2be94  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f010e79153a       1152  (unknown)
    @     0x7f025a799090       3488  (unknown)
    @     0x7f0116af7994        144  torch_xla::(anonymous namespace)::AllReduceInPlace()
    @     0x7f0116b0d463        176  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f0116afbf20        528  pybind11::cpp_function::dispatcher()
    @           0x5d5499  1338095392  PyCFunction_Call
    @           0x8fdb00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f0116c2be94,7f010e791539,7f025a79908f,7f0116af7993,7f0116b0d462,7f0116afbf1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7f01122e2000-7f0120b21e40,abbd016d9542b8098892badc0b19ea68:7f01015e7000-7f010e9a5cf0 
E1207 22:03:04.196268   59813 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1207 22:03:04.196289   59813 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1207 22:03:04.196341   59813 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1207 22:03:04.196351   59813 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1207 22:03:04.196381   59813 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1207 22:03:04.196397   59813 coredump_hook.cc:603] RAW: Dumping core locally.
E1207 22:03:04.516013   59813 process_state.cc:783] RAW: Raising signal 11 with default behavior
https://symbolize.stripped_domain/r/?trace=7f1c5fad3e94,7f1da364108f,7f1c5f99f993,7f1c5f9b5462,7f1c5f9a3f1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7f1c5b18a000-7f1c699c9e40 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 58457 (TID 59810) on cpu 91; stack trace: ***
PC: @     0x7f1c5fad3e94  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f1c5763953a       1152  (unknown)
    @     0x7f1da3641090       3488  (unknown)
    @     0x7f1c5f99f994        144  torch_xla::(anonymous namespace)::AllReduceInPlace()
    @     0x7f1c5f9b5463        176  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f1c5f9a3f20        528  pybind11::cpp_function::dispatcher()
    @           0x5d5499  1338098528  PyCFunction_Call
    @           0x8fdb00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f1c5fad3e94,7f1c57639539,7f1da364108f,7f1c5f99f993,7f1c5f9b5462,7f1c5f9a3f1f,5d5498,8fdaff&map=06b7eaee513554b0b69f7d4d65fa69f6858d5374:7f1c5b18a000-7f1c699c9e40,abbd016d9542b8098892badc0b19ea68:7f1c4a48f000-7f1c5784dcf0 
E1207 22:03:04.957925   59810 coredump_hook.cc:447] RAW: Remote crash data gathering hook invoked.
E1207 22:03:04.957947   59810 coredump_hook.cc:486] RAW: Skipping coredump since rlimit was 0 at process start.
E1207 22:03:04.957960   59810 client.cc:272] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1207 22:03:04.957967   59810 coredump_hook.cc:542] RAW: Sending fingerprint to remote end.
E1207 22:03:04.957997   59810 coredump_hook.cc:551] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1207 22:03:04.958008   59810 coredump_hook.cc:603] RAW: Dumping core locally.
E1207 22:03:05.231006   59810 process_state.cc:783] RAW: Raising signal 11 with default behavior
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
_pickle.UnpicklingError: pickle data was truncated
https://symbolize.stripped_domain/r/?trace=7fc4a2b7a454,7fc4a2bce08f&map= 
*** SIGTERM received by PID 58455 (TID 58455) on cpu 84 from PID 58367; stack trace: ***
PC: @     0x7fc4a2b7a454  (unknown)  do_futex_wait.constprop.0
    @     0x7fc35407753a       1152  (unknown)
    @     0x7fc4a2bce090  (unknown)  (unknown)
    @ ... and at least 2 more frames
https://symbolize.stripped_domain/r/?trace=7fc4a2b7a454,7fc354077539,7fc4a2bce08f&map=abbd016d9542b8098892badc0b19ea68:7fc346ecd000-7fc35428bcf0 
E1207 22:03:06.356699   58455 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM.
https://symbolize.stripped_domain/r/?trace=7fe036444454,7fe03649808f&map= 
*** SIGTERM received by PID 58454 (TID 58454) on cpu 47 from PID 58367; stack trace: ***
PC: @     0x7fe036444454  (unknown)  do_futex_wait.constprop.0
    @     0x7fded98f953a       1152  (unknown)
    @     0x7fe036498090  (unknown)  (unknown)
    @ ... and at least 2 more frames
https://symbolize.stripped_domain/r/?trace=7fe036444454,7fded98f9539,7fe03649808f&map=abbd016d9542b8098892badc0b19ea68:7fdecc74f000-7fded9b0dcf0 
E1207 22:03:06.458144   58454 coredump_hook.cc:393] RAW: Remote crash gathering disabled for SIGTERM.
E1207 22:03:06.624897   58455 process_state.cc:783] RAW: Raising signal 15 with default behavior
E1207 22:03:06.646056   58454 process_state.cc:783] RAW: Raising signal 15 with default behavior
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
_pickle.UnpicklingError: pickle data was truncated
Traceback (most recent call last):
  File "/home/mfatih/17_featureMatching/run_train_1_1_TPU_multi.py", line 57, in <module>
    xmp.spawn(train_and_val, args=(FLAGS,) )
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 202, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 159, in run_multiprocess
    replica_results = list(
  File "/home/mfatih/env3_8/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 160, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
_pickle.UnpicklingError: pickle data was truncated
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
_pickle.UnpicklingError: pickle data was truncated
JackCaoG commented 10 months ago

I think the code crashed in https://github.com/pytorch/xla/blob/a01de3924ca92bbc80b2dee9102b0fbcc236af5c/torch_xla/csrc/tensor_methods.cpp#L339-L367

But it's hard for me to tell where exactly it crashes. Do you have a small repo?
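A small, self-contained repro for the collective itself could look like the sketch below. It exercises only xm.all_reduce across the spawned replicas, with no model, data, or optimizer, so a crash here would point at the runtime rather than at the training script. This is an illustrative sketch, not code from the thread: expected_sum, _mp_fn and main are hypothetical names, and xr.world_size() assumes torch_xla 2.1 or later.

```python
# Hypothetical all_reduce smoke test; launch with PJRT_DEVICE=TPU.
# torch / torch_xla are imported inside the functions so the pure helper
# can be imported (and tested) on a machine without a TPU.


def expected_sum(world_size):
    """Each replica contributes (ordinal + 1), so the total is 1 + 2 + ... + world_size."""
    return world_size * (world_size + 1) // 2


def _mp_fn(index):
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr

    device = xm.xla_device()
    ordinal = xm.get_ordinal()
    t = torch.full((4,), float(ordinal + 1), device=device)

    # Passing a list makes xm.all_reduce reduce the tensors in place
    # (assumed to go through the same in-place binding seen in the trace).
    xm.all_reduce(xm.REDUCE_SUM, [t])
    xm.mark_step()

    got = t.cpu()
    want = expected_sum(xr.world_size())
    assert bool(torch.all(got == want)), (ordinal, got, want)
    print(f"ordinal {ordinal}: all_reduce OK {got.tolist()}", flush=True)


def main():
    import torch_xla.distributed.xla_multiprocessing as xmp

    xmp.spawn(_mp_fn, args=())
```

Call main() from under `if __name__ == '__main__':` in the launching script. If every replica prints the same reduced tensor, the collective works on its own and the problem is more likely in how the training step invokes it.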

mfatih7 commented 10 months ago

I am preparing a repo to share.

mfatih7 commented 10 months ago

Hello @JackCaoG

Here is the repo for debugging. The input data is included so the issue can be reproduced.

Just update line 21 to match your file system and run run_train_1_1_TPU_multi.py.

Please ignore my debug prints at the start. I am still working on the code, and it most probably has errors.

best regards

mfatih7 commented 10 months ago

Hello @JackCaoG

Does the repo work for you? I am happy to do anything that helps.

JackCaoG commented 10 months ago

I probably only have bandwidth to handle one issue (https://github.com/pytorch/xla/issues/6002). @ManfeiBai, do you have cycles to repro this?

ManfeiBai commented 10 months ago

thanks, will do

mfatih7 commented 10 months ago

Hello @ManfeiBai

I can make any modifications to the repo that would help.

best regards

ManfeiBai commented 10 months ago

Thanks, @mfatih7.

I'm trying to reproduce the issue. I ran the following commands:

pip install h5py
pip install opencv-python
apt-get update && apt-get install libgl1
pip install torchsummary
pip install thop
pip install matplotlib
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

and now get this output:

# PJRT_DEVICE=TPU python 17_featureMatching/run_train_1_1_TPU_multi.py
Number of chunks 1
Training starts for train_1_1_each_sample_in_single_batch
TPU:1 DEB PNT 0
TPU:2 DEB PNT 0
TPU:3 DEB PNT 0
Master Print by Process 0 using TPU:0
TPU:0 DEB PNT 0
TPU:2 DEB PNT 1
TPU:2 DEB PNT 2
TPU:0 DEB PNT 1
TPU:0 DEB PNT 2
TPU:3 DEB PNT 1
TPU:3 DEB PNT 2
TPU:1 DEB PNT 1
TPU:1 DEB PNT 2
TPU:0 DEB PNT 3
TPU:2 DEB PNT 3
TPU:3 DEB PNT 3
TPU:1 DEB PNT 3
TPU:2 DEB PNT 4
TPU:0 DEB PNT 4
TPU:1 DEB PNT 4
TPU:2 DEB PNT 4A
TPU:3 DEB PNT 4
TPU:0 DEB PNT 4A
TPU:2 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:1 DEB PNT 4A
TPU:0 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:1 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:2 DEB PNT 4AAA /path/issue6048/17_featureMatching
TPU:3 DEB PNT 4A
TPU:0 DEB PNT 4AAA /path/issue6048/17_featureMatching
TPU:1 DEB PNT 4AAA /path/issue6048/17_featureMatching
TPU:3 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:3 DEB PNT 4AAA /path/issue6048/17_featureMatching
TPU:2 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:3 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:1 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:0 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:2 DEB PNT 4B
TPU:2 DEB PNT 4C
TPU:0 DEB PNT 4B
TPU:0 DEB PNT 4C
TPU:3 DEB PNT 4B
TPU:3 DEB PNT 4C
TPU:1 DEB PNT 4B
TPU:1 DEB PNT 4C
TPU:2 DEB PNT 4D
TPU:3 DEB PNT 4D
TPU:0 DEB PNT 4D
TPU:1 DEB PNT 4D
TPU:2 DEB PNT 4E
TPU:2 DEB PNT 4F
TPU:2 DEB PNT 5
TPU:1 DEB PNT 4E
TPU:1 DEB PNT 4F
TPU:1 DEB PNT 5
TPU:0 DEB PNT 4E
TPU:0 DEB PNT 4F
TPU:0 DEB PNT 5
TPU:3 DEB PNT 4E
TPU:3 DEB PNT 4F
TPU:3 DEB PNT 5
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
https://symbolize.stripped_domain/r/?trace=7f29bff61584,7f2a98a5e13f,7f29bfe2862e,7f29bfe128bd,4fc696&map= 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 159 (TID 3414) on cpu 239; stack trace: ***
PC: @     0x7f29bff61584  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f257f2f5067        928  (unknown)
    @     0x7f2a98a5e140       1648  (unknown)
    @     0x7f29bfe2862f        256  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f29bfe128be        512  pybind11::cpp_function::dispatcher()
    @           0x4fc697  (unknown)  cfunction_call
https://symbolize.stripped_domain/r/?trace=7f581b8c1584,7f58f43bb13f,7f581b78862e,7f581b7728bd,4fc696&map= 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 158 (TID 3402) on cpu 201; stack trace: ***
    @          0x6cc6530  561919264  (unknown)
    @          0x7d91750    5771104  (unknown)
PC: @     0x7f581b8c1584  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f53df2f5067        928  (unknown)
    @     0x7f58f43bb140       1648  (unknown)
    @     0x7f1a29574c70  (unknown)  (unknown)
    @     0x7f581b78862f        256  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f1a29af5bd0       1920  (unknown)
    @     0x7f581b7728be        512  pybind11::cpp_function::dispatcher()
    @           0x4fc697  (unknown)  cfunction_call
    @          0x619c560  1781461296  (unknown)
    @     0x7f1ac000ebd0       1376  (unknown)
    @     0x7f1ac000f350  (unknown)  (unknown)
    @          0x726b030    3765616  (unknown)
    @     0x7f1ac000f8b0  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f29bff61584,7f257f2f5066,7f2a98a5e13f,7f29bfe2862e,7f29bfe128bd,4fc696,6cc652f,7d9174f,7f1a29574c6f,7f1a29af5bcf,7f1ac000ebcf,7f1ac000f34f,7f1ac000f8af&map= 
    @     0x7f487155a160  1785832720  (unknown)
E1212 20:16:20.468024    3414 coredump_hook.cc:442] RAW: Remote crash data gathering hook invoked.
E1212 20:16:20.468044    3414 client.cc:269] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1212 20:16:20.468049    3414 coredump_hook.cc:537] RAW: Sending fingerprint to remote end.
E1212 20:16:20.468085    3414 coredump_hook.cc:546] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1212 20:16:20.468091    3414 coredump_hook.cc:598] RAW: Dumping core locally.
    @     0x7f48718f16d0       1920  (unknown)
    @     0x7f48dc00bbe0       1376  (unknown)
    @     0x7f48dc00c360  (unknown)  (unknown)
    @     0x7f48dc00c8c0  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f581b8c1584,7f53df2f5066,7f58f43bb13f,7f581b78862e,7f581b7728bd,4fc696,619c55f,726b02f,7f487155a15f,7f48718f16cf,7f48dc00bbdf,7f48dc00c35f,7f48dc00c8bf&map= 
E1212 20:16:20.489174    3402 coredump_hook.cc:442] RAW: Remote crash data gathering hook invoked.
E1212 20:16:20.489187    3402 client.cc:269] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1212 20:16:20.489192    3402 coredump_hook.cc:537] RAW: Sending fingerprint to remote end.
E1212 20:16:20.489213    3402 coredump_hook.cc:546] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1212 20:16:20.489219    3402 coredump_hook.cc:598] RAW: Dumping core locally.
https://symbolize.stripped_domain/r/?trace=7f1c7e14b584,7f1d56c4a13f,7f1c7e01262e,7f1c7dffc8bd,4fc696&map= 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 160 (TID 3418) on cpu 100; stack trace: ***
PC: @     0x7f1c7e14b584  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f18432f5067        928  (unknown)
    @     0x7f1d56c4a140       1648  (unknown)
    @     0x7f1c7e01262f        256  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f1c7dffc8be        512  pybind11::cpp_function::dispatcher()
    @           0x4fc697  (unknown)  cfunction_call
    @          0x73beed0  (unknown)  (unknown)
    @          0x84541a0  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f3122cb5584,7f31fb7af13f,7f3122b7c62e,7f3122b668bd,4fc696&map= 
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 155 (TID 3417) on cpu 63; stack trace: ***
    @     0x7f0cd67e6110  (unknown)  (unknown)
PC: @     0x7f3122cb5584  (unknown)  torch_xla::tensor_methods::all_reduce()
    @     0x7f2ce32f5067        928  (unknown)
    @     0x7f31fb7af140       1648  (unknown)
    @     0x7f0cd5af5860       1920  (unknown)
    @     0x7f3122b7c62f        256  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f0d6000fb50       1376  (unknown)
    @     0x7f3122b668be        512  pybind11::cpp_function::dispatcher()
    @           0x4fc697  (unknown)  cfunction_call
    @     0x7f0d600102d0  (unknown)  (unknown)
    @          0x5b4fc90  2129690352  (unknown)
    @     0x7f0d60010830  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f1c7e14b584,7f18432f5066,7f1d56c4a13f,7f1c7e01262e,7f1c7dffc8bd,4fc696,73beecf,845419f,7f0cd67e610f,7f0cd5af585f,7f0d6000fb4f,7f0d600102cf,7f0d6001082f&map= 
E1212 20:16:20.688580    3418 coredump_hook.cc:442] RAW: Remote crash data gathering hook invoked.
E1212 20:16:20.688595    3418 client.cc:269] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1212 20:16:20.688601    3418 coredump_hook.cc:537] RAW: Sending fingerprint to remote end.
E1212 20:16:20.688624    3418 coredump_hook.cc:546] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1212 20:16:20.688629    3418 coredump_hook.cc:598] RAW: Dumping core locally.
    @          0x6bed6b0   11435760  (unknown)
    @     0x7f2185af55a0  (unknown)  (unknown)
    @     0x7f21865dd490       1920  (unknown)
    @     0x7f222800ec50       1376  (unknown)
    @     0x7f222800f3d0  (unknown)  (unknown)
    @     0x7f222800f930  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f3122cb5584,7f2ce32f5066,7f31fb7af13f,7f3122b7c62e,7f3122b668bd,4fc696,5b4fc8f,6bed6af,7f2185af559f,7f21865dd48f,7f222800ec4f,7f222800f3cf,7f222800f92f&map= 
E1212 20:16:20.717141    3417 coredump_hook.cc:442] RAW: Remote crash data gathering hook invoked.
E1212 20:16:20.717155    3417 client.cc:269] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1212 20:16:20.717159    3417 coredump_hook.cc:537] RAW: Sending fingerprint to remote end.
E1212 20:16:20.717180    3417 coredump_hook.cc:546] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E1212 20:16:20.717186    3417 coredump_hook.cc:598] RAW: Dumping core locally.
E1212 20:16:52.659608    3402 process_state.cc:807] RAW: Raising signal 11 with default behavior
E1212 20:16:52.660995    3417 process_state.cc:807] RAW: Raising signal 11 with default behavior
E1212 20:16:52.664395    3414 process_state.cc:807] RAW: Raising signal 11 with default behavior
E1212 20:16:52.667387    3418 process_state.cc:807] RAW: Raising signal 11 with default behavior
Traceback (most recent call last):
  File "/root/issue6048/17_featureMatching/run_train_1_1_TPU_multi.py", line 57, in <module>
    xmp.spawn(train_and_val, args=(FLAGS,) )
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/runtime.py", line 87, in wrapper
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
    replica_results = list(
  File "/root/miniconda3/envs/torch310/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
    itertools.chain.from_iterable(
  File "/root/miniconda3/envs/torch310/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/root/miniconda3/envs/torch310/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/root/miniconda3/envs/torch310/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/root/miniconda3/envs/torch310/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/root/miniconda3/envs/torch310/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
/root/miniconda3/envs/torch310/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 68 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

I have four devices locally, as shown below. How many devices is run_train_1_1_TPU_multi.py supposed to use?

# PJRT_DEVICE=TPU python
Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> import torch_xla.runtime as xr
>>> num_devices = xr.global_runtime_device_count()
>>> num_devices
4
>>> 
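To see how the spawned replicas map onto devices on a given TPU VM, each replica can report its own ordinal and device, as in the sketch below. My understanding of the PJRT runtime documentation (worth double-checking for your version) is that on TPU v4 each chip appears as one device, while on TPU v2/v3 xmp.spawn starts one process per chip and each process drives its two TensorCores from two threads. That would explain 4 devices on a v4-8 and 8 replicas, printed as xla:0 and xla:1, on a v2-8 or v3-8. describe_replica, _mp_fn and main are hypothetical names, and xr.world_size() assumes torch_xla 2.1 or later.

```python
# Hypothetical per-replica report; launch with PJRT_DEVICE=TPU.
# torch_xla is imported inside the functions so the formatter is importable anywhere.


def describe_replica(index, ordinal, local_ordinal, world_size, device):
    """Format one line describing a spawned replica."""
    return (f"spawn index={index} global ordinal={ordinal}/{world_size} "
            f"local ordinal={local_ordinal} device={device}")


def _mp_fn(index):
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr

    print(describe_replica(index, xm.get_ordinal(), xm.get_local_ordinal(),
                           xr.world_size(), xm.xla_device()), flush=True)


def main():
    import torch_xla.distributed.xla_multiprocessing as xmp

    # Avoid touching the TPU runtime in the parent before spawning;
    # each replica queries its own view of the devices instead.
    xmp.spawn(_mp_fn, args=())
```

Call main() from under `if __name__ == '__main__':`; the number of lines printed is the number of replicas the training script would run on.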
mfatih7 commented 10 months ago

Hello @ManfeiBai

Thank you for your answer.

The error is similar to the one I get. I am testing on TPU v2 and TPU v3, both of which have 8 cores, and I think multi-core operation uses all of them. In your case 4 cores are used, which I think is also normal.

The Python version is different in your case: I am using Python 3.8, and you are using Python 3.10.
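Since the two environments differ (Python 3.8 vs 3.10, and possibly different torch / torch_xla builds), it may help to post an exact version report from each machine. A minimal sketch using only the standard library; version_report is a hypothetical helper, and it simply reads each package's `__version__` attribute if the package imports.

```python
import importlib
import platform


def version_report(packages=("torch", "torch_xla")):
    """Return the Python version plus the __version__ of each package, if importable."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except Exception as exc:  # ImportError, or a failure inside the package's own import
            report[name] = f"unavailable ({type(exc).__name__})"
        else:
            report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    for key, value in version_report().items():
        print(f"{key}: {value}")
```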

I can remove the "TPU:1 DEB PNT 0"-style debug prints and commit again. However, they help me confirm that multi-core execution works correctly when creating the datasets, samplers, and dataloaders.

Anything I can do to help?

ManfeiBai commented 10 months ago

Testing locally again and debugging with pdb, the code seems to crash inside xmp.spawn. So I tested the code up to xmp.spawn with a simple worker function instead of train_and_val, as below, and it succeeded:

$ cat FeatureMatchingDebug/17_featureMatching/run_train_1_1_TPU_multi.py
from config import get_config

from  train_1_1_each_sample_in_single_batch_TPU_multi import train_and_val

import torch_xla.distributed.xla_multiprocessing as xmp

import tpu_related.set_env_variables_for_TPU as set_env_variables_for_TPU

import torch
import torch_xla.core.xla_model as xm

# Wrap most of your main script’s code within an if __name__ == '__main__': block, to make sure it doesn’t run again
# (most likely generating error) when each worker process is launched. You can place your dataset and DataLoader
# instance creation logic here, as it doesn’t need to be re-executed in workers.

def _mp_fn(index, flags):
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)

if __name__ == '__main__':

    set_env_variables_for_TPU.set_env_variables_for_TPU_PJRT( )

    config = get_config()

    set_env_variables_for_TPU.set_env_debug_variables_for_TPU_PJRT( config )

    experiment = config.first_experiment

    config.copy_config_file_to_output_folder( experiment )

    # learning_rate = config.learning_rate
    # n_epochs = config.n_epochs
    # num_workers = config.num_workers
    # model_type = config.model_type
    # bn_or_gn = config.bn_or_gn
    # optimizer_type = config.optimizer_types[0]
    # en_grad_checkpointing = config.en_grad_checkpointing
    input_type = config.training_params[0][0]
    N_images_in_batch = config.training_params[0][1]
    N = config.training_params[0][2]
    batch_size = config.training_params[0][3]

    if(input_type=='1_to_1'):

        if( N_images_in_batch >= 1 and N == batch_size ):

            print('Training starts for ' + 'train_1_1_each_sample_in_single_batch')

            FLAGS = {}
            FLAGS['config']                     = config
            FLAGS['learning_rate']              = config.learning_rate
            FLAGS['n_epochs']                   = config.n_epochs
            FLAGS['num_workers']                = config.num_workers
            FLAGS['model_type']                 = config.model_type
            FLAGS['bn_or_gn']                   = config.bn_or_gn
            FLAGS['optimizer_type']             = config.optimizer_types[0]
            FLAGS['en_grad_checkpointing']      = config.en_grad_checkpointing
            FLAGS['input_type']                 = config.training_params[0][0]
            FLAGS['N_images_in_batch']          = config.training_params[0][1]
            FLAGS['N']                          = config.training_params[0][2]
            FLAGS['batch_size']                 = config.training_params[0][3]

            print("num_workers: ", config.num_workers)

            xmp.spawn(_mp_fn, args=(FLAGS,) )
        else:
            raise ValueError(f"The provided arguments are not valid: {input_type} {N_images_in_batch} {N} {batch_size}")
    else:
        raise ValueError(f"The provided argument is not valid: {input_type}") 

Successful log:

$ PJRT_DEVICE=TPU python3 FeatureMatchingDebug/17_featureMatching/run_train_1_1_TPU_multi.py
Number of chunks 1
Training starts for train_1_1_each_sample_in_single_batch
num_workers:  3
xla:0
xla:0
xla:1
xla:1
xla:0
xla:1
xla:0
xla:1
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')

tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')

The log shows 8 processes, so I changed num_workers in the config to 8 and tested again; it succeeded:

$ PJRT_DEVICE=TPU python3 FeatureMatchingDebug/17_featureMatching/run_train_1_1_TPU_multi.py
Number of chunks 1
Training starts for train_1_1_each_sample_in_single_batch
num_workers:  8
xla:1
xla:0
xla:0
xla:1
xla:0
xla:1
xla:0
xla:1
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')
tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:0')

So the next step is to check where the code crashes inside the train_and_val function.

Since test/test_train_mp_mnist.py also ran successfully on a TPU v3-8 with Python 3.8, I will also use test_train_mp_mnist.py as a reference for the next step.
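For reference, the per-step pattern that torch_xla multiprocessing examples such as test/test_train_mp_mnist.py follow is, as far as I know, roughly the one below: MpDeviceLoader feeds each replica and marks a step per batch, and xm.optimizer_step all-reduces the gradients across replicas before calling optimizer.step(). This is a simplified sketch, not a copy of that test; train_one_epoch and its arguments are hypothetical, and it assumes the model is already on the XLA device and the optimizer was built from its parameters.

```python
def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one epoch on this replica's XLA device (sketch)."""
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    model.train()
    # MpDeviceLoader moves each batch to the device and marks a step per iteration.
    for data, target in pl.MpDeviceLoader(loader, device):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        # All-reduce gradients across replicas, then call optimizer.step().
        xm.optimizer_step(optimizer)
```

Because xm.optimizer_step does both the gradient all-reduce and the parameter update, replacing it with xm.mark_step() alone would leave the weights unchanged between steps.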

mfatih7 commented 10 months ago

Hello @ManfeiBai

Thank you for your answer

Here is more information. I kept adding debug points to the code in the repo I shared, and I found that the problematic part is the optimizer step. If you comment out only that line, the iterations finish without the error.

I appreciate any help in solving the problem. Just tell me if there is anything I can do to help.

mfatih7 commented 10 months ago

Do you think the model replicas on the TPU cores somehow differ?

Here, it is written that

To get consistent parameters between replicas, either use torch_xla.experimental.pjrt.broadcast_master_param to broadcast one replica's parameters to all other replicas, or load each replica's parameters from a common checkpoint.

I am trying to implement the second option: loading each replica's parameters from a common checkpoint.
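A sketch of that second option, plus a quick way to check whether the replicas actually differ. It assumes xm.save writes only from the master ordinal and then waits at a rendezvous (so every replica must call it), and it uses xm.mesh_reduce to compare a per-replica parameter checksum. load_common_checkpoint, param_checksum, replicas_agree and check_replicas_consistent are hypothetical helpers, and the checksum is a cheap heuristic, not a proof that the parameters are identical.

```python
import math


def replicas_agree(checksums, rel_tol=1e-6):
    """True if every replica's checksum matches the first one within tolerance."""
    first = checksums[0]
    return all(math.isclose(c, first, rel_tol=rel_tol, abs_tol=1e-9) for c in checksums)


def param_checksum(model):
    """Sum of all parameter values, computed on CPU in float64 (a cheap fingerprint)."""
    return sum(float(p.detach().cpu().double().sum()) for p in model.parameters())


def load_common_checkpoint(model, path):
    """Make every replica start from the master replica's parameters (sketch)."""
    import torch
    import torch_xla.core.xla_model as xm

    # Every replica calls xm.save: only the master ordinal writes the file,
    # and all replicas then meet at a rendezvous before anyone reads it.
    xm.save(model.state_dict(), path)
    model.load_state_dict(torch.load(path, map_location="cpu"))


def check_replicas_consistent(model):
    """Gather each replica's checksum and report whether they all agree."""
    import torch_xla.core.xla_model as xm

    ok = xm.mesh_reduce("param_checksum", param_checksum(model), replicas_agree)
    xm.master_print(f"replica parameters consistent: {ok}")
    return ok
```

For example, call load_common_checkpoint(model, path) in every replica after constructing the model and before training starts, then call check_replicas_consistent(model) once as a sanity check.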

ManfeiBai commented 10 months ago

Thanks, my local test also crashed at this line, xm.optimizer_step(optimizer): https://github.com/mfatih7/FeatureMatchingDebug/blob/657ef2ebba333de3ba9a6cd887ce6aa00fbf4d67/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py#L151

Thanks, @mfatih7. My next step is to change the loss function to loss_fn = nn.NLLLoss() to see whether that makes a difference, and to confirm how the replicas are currently copied.
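Because the crash is inside all_reduce and only shows up when the optimizer step runs, one way to narrow it down is to split xm.optimizer_step(optimizer) into its two parts, xm.reduce_gradients(optimizer) followed by optimizer.step(), and to check first whether any parameter actually has a gradient. The empty-gradient angle is an unconfirmed hypothesis: as far as I can tell, torch_xla collects only the non-None gradients before the all-reduce, so an optimizer whose parameters never received gradients would pass it an empty list. count_grads, params_without_grad and debug_optimizer_step are hypothetical helpers.

```python
def count_grads(optimizer):
    """Number of parameters in the optimizer that currently have a gradient."""
    return sum(1 for group in optimizer.param_groups
               for p in group["params"] if p.grad is not None)


def params_without_grad(optimizer):
    """(group index, param index) of trainable parameters whose grad is still None."""
    return [(gi, pi)
            for gi, group in enumerate(optimizer.param_groups)
            for pi, p in enumerate(group["params"])
            if p.requires_grad and p.grad is None]


def debug_optimizer_step(optimizer):
    """Stand-in for xm.optimizer_step(optimizer) while debugging (sketch)."""
    import torch_xla.core.xla_model as xm

    missing = params_without_grad(optimizer)
    if missing:
        xm.master_print(f"{len(missing)} trainable params have no grad, e.g. {missing[:5]}")
    if count_grads(optimizer) == 0:
        # Stop here instead of entering the all-reduce with nothing to reduce.
        raise RuntimeError("no parameter has a gradient; check loss.backward() and the param groups")
    xm.reduce_gradients(optimizer)  # the cross-replica all-reduce
    optimizer.step()                # the purely local parameter update
```

If the crash moves to the xm.reduce_gradients line, the collective is the problem; if the RuntimeError fires instead, the gradients never reached the optimizer's parameters.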

ManfeiBai commented 10 months ago

~Thanks, @mfatih7. Do you want to try again after changing xm.optimizer_step(optimizer) in FeatureMatchingDebug/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py to xm.mark_step()?

I tested locally with this change, and training and validation finished on v4-8:~

# PJRT_DEVICE=TPU python 17_featureMatching/run_train_1_1_TPU_multi.py
Number of chunks 1
Training starts for train_1_1_each_sample_in_single_batch
TPU:3 DEB PNT 0
TPU:2 DEB PNT 0
Master Print by Process 0 using TPU:0
TPU:0 DEB PNT 0
TPU:1 DEB PNT 0
TPU:2 DEB PNT 1
TPU:2 DEB PNT 2
TPU:3 DEB PNT 1
TPU:3 DEB PNT 2
TPU:1 DEB PNT 1
TPU:1 DEB PNT 2
TPU:0 DEB PNT 1
TPU:0 DEB PNT 2
TPU:0 DEB PNT 3
TPU:2 DEB PNT 3
TPU:1 DEB PNT 3
TPU:3 DEB PNT 3
TPU:1 DEB PNT 4
TPU:2 DEB PNT 4
TPU:3 DEB PNT 4
TPU:0 DEB PNT 4
TPU:1 DEB PNT 4A
TPU:2 DEB PNT 4A
TPU:3 DEB PNT 4A
TPU:1 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:2 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:3 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:1 DEB PNT 4AAA /root/FeatureMatchingDebug/17_featureMatching
TPU:2 DEB PNT 4AAA /root/FeatureMatchingDebug/17_featureMatching
TPU:0 DEB PNT 4A
TPU:3 DEB PNT 4AAA /root/FeatureMatchingDebug/17_featureMatching
TPU:1 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:0 DEB PNT 4AA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:2 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:3 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:0 DEB PNT 4AAA /root/FeatureMatchingDebug/17_featureMatching
TPU:0 DEB PNT 4AAAA ../08_featureMatchingOutputs/0100/checkpoints/model.pth.tar
TPU:0 DEB PNT 4B
TPU:0 DEB PNT 4C
TPU:2 DEB PNT 4B
TPU:2 DEB PNT 4C
TPU:3 DEB PNT 4B
TPU:3 DEB PNT 4C
TPU:1 DEB PNT 4B
TPU:1 DEB PNT 4C
TPU:0 DEB PNT 4D
TPU:1 DEB PNT 4D
TPU:2 DEB PNT 4D
TPU:3 DEB PNT 4D
TPU:0 DEB PNT 4E
TPU:0 DEB PNT 4F
TPU:0 DEB PNT 5
TPU:1 DEB PNT 4E
TPU:1 DEB PNT 4F
TPU:1 DEB PNT 5
TPU:3 DEB PNT 4E
TPU:3 DEB PNT 4F
TPU:3 DEB PNT 5
TPU:2 DEB PNT 4E
TPU:2 DEB PNT 4F
TPU:2 DEB PNT 5
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
Train Epoch 0/1 Chunk 0/0 Batch 72/72 LR 0.010000 lCls 0.692904 lGeo 0.089983 LEss 1.477992 CorPred 21824/37376 Acc 0.583904 Pre 0.109663 Rec 0.424031 F1 0.174259
Size of val dataset is 122
Size of val dataset is 122
Size of val dataset is 122
Size of val dataset is 122
Val Epoch 0/1 Chunk 0/0 Batch 121/121 LR 0.010000 LossCls 0.691420 lGeo 0.089422 LEss 1.526129 CorPred 36165/62464 Acc 0.578973 Pre 0.121050 Rec 0.429072 F1 0.188828
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
Size of train dataset is 73
Train Epoch 1/1 Chunk 0/0 Batch 72/72 LR 0.010000 lCls 0.692904 lGeo 0.089983 LEss 1.477992 CorPred 21824/37376 Acc 0.583904 Pre 0.109663 Rec 0.424031 F1 0.174259
Size of val dataset is 122
Size of val dataset is 122
Size of val dataset is 122
Size of val dataset is 122
Val Epoch 1/1 Chunk 0/0 Batch 121/121 LR 0.010000 LossCls 0.691420 lGeo 0.089422 LEss 1.526129 CorPred 36165/62464 Acc 0.578973 Pre 0.121050 Rec 0.429072 F1 0.188828
# git status
On branch main
Your branch is up to date with 'origin/main'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
    modified:   17_featureMatching/config.py
    modified:   17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py

Untracked files:
  (use "git add <file>..." to include in what will be committed)
    08_featureMatchingOutputs/

no changes added to commit (use "git add" and/or "git commit -a")
# git diff
error: cannot run less: No such file or directory
diff --git a/17_featureMatching/config.py b/17_featureMatching/config.py
index a97f14e..44d5416 100644
--- a/17_featureMatching/config.py
+++ b/17_featureMatching/config.py
@@ -18,7 +18,7 @@ class Config:
             self.storage_local_or_bucket = 'local'     

         if( self.device == 'tpu' ):
-            os.chdir( os.path.join('/', 'home', 'mfatih', 'FeatureMatchingDebug', '17_featureMatching') )
+            os.chdir( os.path.join('/', 'root', 'FeatureMatchingDebug', '17_featureMatching') )

         self.first_experiment = 100

diff --git a/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py b/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py
index bf5ba4b..c408da9 100644
--- a/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py
+++ b/17_featureMatching/train_1_1_each_sample_in_single_batch_TPU_multi.py
@@ -148,8 +148,9 @@ def train_and_val(index, FLAGS):

                 loss.backward()

-                xm.optimizer_step(optimizer)
-                    
+                # xm.optimizer_step(optimizer)
+                xm.mark_step()
+
                 confusion_matrix_at_epoch_train_device[0,0] += torch.sum( torch.logical_and( logits<0, labels_device>config.obj_geod_th ) )
                 confusion_matrix_at_epoch_train_device[0,1] += torch.sum( torch.logical_and( logits>0, labels_device>config.obj_geod_th ) )
                 confusion_matrix_at_epoch_train_device[1,0] += torch.sum( torch.logical_and( logits<0, labels_device<config.obj_geod_th ) )

~The reason is that v2-8/v3-8 are single-host devices, so we use xm.mark_step() as in this single-host example.~

This solution does not apply to multi-core training on v2-8/v3-8; @mfatih7 gives the correct solution below.
mfatih7 commented 10 months ago

Hello @ManfeiBai

Thank you very much for your effort.

But I think something is wrong. As far as I know, TPU v2 and TPU v3 devices have 8 cores, and we can use either a single core or all 8 cores together for faster processing.

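Disabling xm.optimizer_step(optimizer) and replacing it with xm.mark_step() is not a correct solution. xm.optimizer_step(optimizer) reduces the gradients across replicas and then calls optimizer.step() (and it can also issue xm.mark_step() itself when called with barrier=True), so calling xm.mark_step() alone skips the parameter update entirely.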
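To make the difference concrete, here is a toy, pure-Python simulation with no torch_xla involved. It assumes plain SGD, and the helpers `all_reduce_mean`, `optimizer_step` and `mark_step_only` are hypothetical stand-ins, not torch_xla APIs: `optimizer_step` averages the per-replica gradients (the role of the cross-replica reduction inside xm.optimizer_step) and then applies the update, while `mark_step_only` leaves the parameters untouched. That behavior is consistent with the identical epoch 0 and epoch 1 metrics in the log above.

```python
from typing import List

Params = List[float]


def all_reduce_mean(grads_per_replica: List[Params]) -> Params:
    """Toy all-reduce: average the gradients element-wise across replicas."""
    n = len(grads_per_replica)
    return [sum(vals) / n for vals in zip(*grads_per_replica)]


def optimizer_step(params_per_replica: List[Params],
                   grads_per_replica: List[Params],
                   lr: float) -> List[Params]:
    """Toy analogue of xm.optimizer_step: reduce gradients, then SGD on every replica."""
    avg = all_reduce_mean(grads_per_replica)
    return [[p - lr * g for p, g in zip(params, avg)] for params in params_per_replica]


def mark_step_only(params_per_replica: List[Params],
                   grads_per_replica: List[Params],
                   lr: float) -> List[Params]:
    """Toy analogue of calling only xm.mark_step(): no reduction, no update."""
    return [list(params) for params in params_per_replica]
```

Because every replica applies the same averaged gradient, the replicas stay in sync after `optimizer_step`; with `mark_step_only` they never change.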

I think I have found the solution to the problem: adding model = model.to(device) at this line.

I think the model must be on the device before the initial checkpoint is saved.

ManfeiBai commented 10 months ago

Thanks, @mfatih7, you are right: we should use xm.optimizer_step(optimizer) with multiprocessing. Glad to see you found the solution. I verified it on my v3-8, and the code finished with training/validation logs too, so shall we mark this issue as closed?

mfatih7 commented 10 months ago

Thank you @ManfeiBai

Regarding this post, I want to ask a final question.

On this line, I am loading the checkpoint and updating the model parameters on all cores.

Is this the correct option?

I ask because the documentation here states:

To get consistent parameters between replicas, either use torch_xla.experimental.pjrt.broadcast_master_param to broadcast one replica's parameters to all other replicas, or load each replica's parameters from a common checkpoint.

Should I use xm.broadcast_master_param(model) as shown here?

ManfeiBai commented 9 months ago

Thanks, good question. I synced with @JackCaoG offline. For question 1: loading the checkpoint and updating the model parameters on all cores, as in your line, is the correct option.

For question 2: with the current code we don't need to run xm.broadcast_master_param(model), because xm.save uses master_only=True, so all cores load the checkpoint saved by the master process.
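The pattern described above can be sketched as a toy, pure-Python simulation with no torch_xla involved: only ordinal 0 writes the checkpoint (mimicking master_only=True), and every replica then loads that same file, so all replicas end up with identical parameters without any broadcast. The list-of-dicts `replicas`, the hypothetical `save_master_only` and `load_all` helpers, and the JSON file format are illustrative assumptions; in the actual script the calls are xm.save(...) for saving and torch.load(...) plus model.load_state_dict(...) for loading. In a real multi-process run the loaders must not read before the master has finished writing; the sequential calls here sidestep that.

```python
import json
import os
import tempfile
from typing import Dict, List

Params = Dict[str, float]


def save_master_only(replicas: List[Params], path: str) -> None:
    """Toy analogue of master_only=True: only ordinal 0 writes the checkpoint."""
    for ordinal, params in enumerate(replicas):
        if ordinal == 0:
            with open(path, "w") as f:
                json.dump(params, f)


def load_all(replicas: List[Params], path: str) -> None:
    """Every replica loads the same common checkpoint, updating its params in place."""
    for params in replicas:
        with open(path) as f:
            loaded = json.load(f)
        params.clear()
        params.update(loaded)
```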

When xm.save(..., master_only=False) is used, we can call xm.broadcast_master_param(model) to make sure all processes have the same parameters. More generally, whenever different cores hold different parameters, xm.broadcast_master_param(model) can make the parameters consistent across replicas.
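For the divergent case, the idea behind xm.broadcast_master_param(model) can be illustrated in plain Python: every replica overwrites its parameters with those of replica 0. This is a toy simulation using a hypothetical `broadcast_master_params` helper over a list of dicts, not the torch_xla implementation, which operates on the model's tensors across the XLA devices.

```python
import copy
from typing import Dict, List

Params = Dict[str, float]


def broadcast_master_params(replicas: List[Params]) -> None:
    """Copy replica 0's parameters into every other replica, in place."""
    master = copy.deepcopy(replicas[0])
    for params in replicas[1:]:
        params.clear()
        # Deep-copy so replicas never share mutable state with the master.
        params.update(copy.deepcopy(master))
```

After the call, every replica holds the master's values, which is the same end state the common-checkpoint approach reaches.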

mfatih7 commented 9 months ago

Thank you for your effort @ManfeiBai