Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Hang using thunder.jit with tokenizer in NeMo Stable Diffusion #462

Open athitten opened 2 months ago

athitten commented 2 months ago

🐛 Bug

NeMo's Stable Diffusion uses CLIPTokenizer from HuggingFace. Adding thunder.jit to the tokenizer is causing a hang.

To Reproduce

tokenizer.patch

Steps to reproduce the behavior:

  1. Apply the attached git diff to NeMo
  2. Run NeMo using the command below:
    python examples/multimodal/text_to_image/stable_diffusion/sd_train.py trainer.precision=16 trainer.num_nodes=1 trainer.devices=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=1 model.global_batch_size=1 model.data.synthetic_data=True exp_manager.exp_dir=/workspace/TestData/multimodal/stable_diffusion_train model.inductor=False model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder ++model.cond_stage_config.version=openai/clip-vit-large-patch14 ++model.cond_stage_config.max_length=77 ~model.cond_stage_config.restore_from_path ~model.cond_stage_config.freeze ~model.cond_stage_config.layer model.unet_config.from_pretrained=null model.first_stage_config.from_pretrained=null model.unet_config.use_flash_attention=False model.unet_config.attention_resolutions=\[1\] model.unet_config.channel_mult=\[1\]

    CC: @tfogal


athitten commented 1 month ago

Digging deeper into this showed that the tokenizer does not contain any torch.nn.Module and only handles tokenization. My guess is that this could be the reason thunder.jit led to a hang.

tfogal commented 1 month ago

> Digging deeper into this showed that the tokenizer does not contain any torch.nn.Module and only handles tokenization. My guess is that this could be the reason thunder.jit led to a hang.

Ack, that's a good theory. Seems like we should have tests where the computation trace ends up empty, though... I wonder if there's comms inside the tokenizer and thunder.jit somehow interferes (or reorders?) them?

For now, we can just avoid using thunder for this part of the model.

Regardless, thunder shouldn't hang. We'll need to reduce this down to something the thunder team can work with. Let's talk offline on that.
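
As a rough illustration of the "avoid using thunder for this part of the model" workaround, here is a minimal sketch that compiles an object only when it is a module and leaves helpers such as the tokenizer untouched. The helper name `jit_if_module` and its parameters are hypothetical (not part of thunder or NeMo); the compile function and the module base class are passed in so the sketch does not depend on either library.

```python
from typing import Any, Callable


def jit_if_module(obj: Any, jit: Callable[[Any], Any], module_type: type) -> Any:
    """Apply `jit` to `obj` only when `obj` is an instance of `module_type`.

    `jit` and `module_type` are supplied by the caller; this hypothetical
    helper makes no assumption about what the compiler does internally.
    """
    if isinstance(obj, module_type):
        return jit(obj)
    # Non-module helpers (for example a tokenizer) are returned unchanged.
    return obj
```

With thunder and PyTorch installed, the call would presumably look like `jit_if_module(tokenizer, thunder.jit, torch.nn.Module)`, which returns the tokenizer as-is because it is not a torch.nn.Module.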

tfogal commented 1 month ago

I attached a debugger and found 95 total threads, but most had uninteresting backtraces, like threads 87--94:

(gdb) bt
#0  0x000075791e330117 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x000075791e332a41 in pthread_cond_wait () from /usr/lib/x86_64-linux-gnu/libc.so.6
#2  0x00007578d8031d57 in c10::ThreadPool::main_loop(unsigned long) ()
   from /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so
#3  0x000075791e14f253 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#4  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#5  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Based on the naming, it seems like there's a threadpool and they're all waiting to be woken up.

The most interesting thread was thread 2, frozen inside UCX:

(gdb) thr 2
[Switching to thread 2 (Thread 0x75786a7ff640 (LWP 51477))]
#0  0x000075791e3b381c in read () from /usr/lib/x86_64-linux-gnu/libc.so.6
(gdb) bt
#0  0x000075791e3b381c in read () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x0000757911caa98d in read (__nbytes=271, __buf=0x75786a7fc370, __fd=<optimized out>)
    at /usr/include/x86_64-linux-gnu/bits/unistd.h:38
#2  ucs_vfs_fuse_wait_for_path (path=path@entry=0x75786a7fd512 "/tmp/ucx-vfs-tfogal.sock") at vfs_fuse.c:303
#3  0x0000757911caaccc in ucs_vfs_fuse_thread_func (arg=<optimized out>) at vfs_fuse.c:409
#4  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#5  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Threads 3 and 56 were stuck somewhere in cuda:

(gdb) bt
#0  0x000075791e3b7bcf in poll () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x0000757864016e4f in ?? () from /usr/local/cuda/compat/lib.real/libcuda.so.1
#2  0x00007578640dcf7f in ?? () from /usr/local/cuda/compat/lib.real/libcuda.so.1
#3  0x000075786400f3f3 in ?? () from /usr/local/cuda/compat/lib.real/libcuda.so.1
#4  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#5  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Thread 17 was seemingly waiting for a socket connection or to recv something:

#0  0x000075791e3b7bcf in poll () from /usr/lib/x86_64-linux-gnu/libc.so.6
(gdb) bt
#0  0x000075791e3b7bcf in poll () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x00007579099aa19d in c10d::detail::TCPStoreMasterDaemon::run() ()
   from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
#2  0x000075791e14f253 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#4  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Threads 20, 24, 38, probably similar to thread 17, but via a different route:

(gdb) bt
#0  0x000075791e3c4e2e in epoll_wait () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x000075790b56b282 in gloo::transport::tcp::Loop::run() () from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
#2  0x000075791e14f253 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#4  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Threads 57, 58, 62, 63, 66, 67, in NCCL:

(gdb) bt
#0  0x000075791e3b7bcf in poll () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x00007578990749c2 in ?? () from /usr/lib/x86_64-linux-gnu/libnccl.so.2
#2  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#3  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

Threads 61 and 65 in accept:

(gdb) bt
#0  0x000075791e3c645f in accept () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x00007578990ac511 in ?? () from /usr/lib/x86_64-linux-gnu/libnccl.so.2
#2  0x00007578990ad18b in ?? () from /usr/lib/x86_64-linux-gnu/libnccl.so.2
#3  0x000075789903ebfd in ?? () from /usr/lib/x86_64-linux-gnu/libnccl.so.2
#4  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#5  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

CPU usage was nil, so it's most likely this was actually blocked as opposed to me catching it at a weird time.

Thread 69 in a poll from deep within a Python _PyEval_EvalFrameDefault.

Threads 70--86 acquiring a Python lock:

(gdb) bt
#0  0x000075791e330117 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x000075791e33bc78 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#2  0x0000630aa98d45d0 in PyThread_acquire_lock_timed ()
#3  0x0000630aa992c461 in ?? ()
#4  0x0000630aa99168b7 in ?? ()
#5  0x0000630aa98fe45c in _PyEval_EvalFrameDefault ()
#6  0x0000630aa99237f1 in ?? ()
#7  0x0000630aa98fe26d in _PyEval_EvalFrameDefault ()
#8  0x0000630aa99159fc in _PyFunction_Vectorcall ()

Thread 95 was apparently waiting to do autograd:

(gdb) bt
#0  0x000075791e330117 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x000075791e332a41 in pthread_cond_wait () from /usr/lib/x86_64-linux-gnu/libc.so.6
#2  0x0000757908f8d49b in torch::autograd::ReadyQueue::pop() ()
   from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
#3  0x0000757908f9146d in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) ()
   from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
#4  0x0000757908f8a0e9 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
#5  0x0000757910f771b5 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) () from /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
#6  0x000075791e14f253 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#7  0x000075791e333ac3 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#8  0x000075791e3c5850 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6

If anyone wants to try reproducing this, note that I needed to add:

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 8048ee7d..00abfc09 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -177,6 +177,11 @@ def is_floating_point(a: TensorLike, /) -> bool:
     return dtypes.is_float_dtype(a.dtype)

+@torchsymbol(torch.Tensor.is_contiguous, is_method=True)
+def is_contiguous(a: TensorLike, memory_format: torch.memory_format) -> bool:
+    return True # hack
+
+
 # Handles the size method
 def size(a: TensorLike, /, dim: None | int = None) -> int | Sequence[int]:
     if dim is not None:

to thunder to reach this point.

tfogal commented 1 month ago

Following where thread 2 is blocked, and especially the comment a few lines above that point, it appears that UCX is waiting for a daemon that never starts. It's not clear whether UCS is a red herring or actually relevant. Since the process finishes when not using thunder.jit, it could mean either 1) thunder somehow prevents the daemon from starting, or 2) the daemon is irrelevant: this thread gets stuck even in the normal case, and its progress does not matter to the test job.

Based on the above, my current theory is that:

tfogal commented 1 month ago

We also see comms from all of gloo; some TCPStoreMasterDaemon thing that appears to be built into PyTorch; NCCL; and UCX. Plus a ton of threads, which makes this hard to debug.

Tagging @eqy for some debugging help:

  1. Is there a way (env var, compile option, something I can manually #if 0) to disable some of these comm sources in PyTorch (to help narrow this down)? Ideally I'd get this down to just a NCCL thing.
  2. Is there a way to force torch to reap threads immediately, or just not bother with a thread pool at all?
  3. Is there any feasible way to go from a backtrace in gdb to whatever line in Python that corresponds to (for the frames that are Python-related)?

mruberry commented 1 month ago

triage review —

@IvanYashchuk @t-vi to follow-up

eqy commented 1 month ago

> We also see comms from all of gloo; some TCPStoreMasterDaemon thing that appears to be built into PyTorch; NCCL; and UCX. Plus a ton of threads, which makes this hard to debug.
>
> Tagging @eqy for some debugging help:
>
>   1. Is there a way (env var, compile option, something I can manually #if 0) to disable some of these comm sources in PyTorch (to help narrow this down)? Ideally I'd get this down to just a NCCL thing.
>   2. Is there a way to force torch to reap threads immediately, or just not bother with a thread pool at all?
>   3. Is there any feasible way to go from a backtrace in gdb to whatever line in Python that corresponds to (for the frames that are Python-related)?

  1. Just seeing this now; all I know of is USE_UCC=0 to disable the backend at build time.

  2. I'm not sure which threadpool this is, and set_num_threads doesn't seem to be relevant here

  3. pdb doesn't seem like it would be useful here, but I wonder if a hack like https://stackoverflow.com/questions/1032813/dump-stacktraces-of-all-active-threads/24334576#24334576 plus a signal handler could dump the relevant stack trace when the hang is reproduced
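
The signal-handler idea in point 3 could be sketched roughly as follows, using only the Python standard library. The function names are hypothetical, the choice of SIGUSR1 is an assumption, and `faulthandler.register` is only available on Unix-like platforms. This only shows Python-level frames; it will not explain threads blocked inside native code such as NCCL or UCX.

```python
import faulthandler
import signal
import sys
import threading
import traceback


def install_stack_dump_handler(signum=signal.SIGUSR1, file=None):
    """On `signum`, write the Python traceback of every thread to `file`.

    `file` must expose a real file descriptor; it defaults to sys.stderr.
    """
    target = file if file is not None else sys.stderr
    faulthandler.register(signum, file=target, all_threads=True, chain=False)


def format_all_thread_stacks():
    """Return the current Python stack of every thread as one string."""
    names = {t.ident: t.name for t in threading.enumerate()}
    chunks = []
    for ident, frame in sys._current_frames().items():
        chunks.append(f"Thread {names.get(ident, '?')} ({ident}):\n")
        chunks.extend(traceback.format_stack(frame))
    return "".join(chunks)
```

Calling `install_stack_dump_handler()` early in the training script and then running `kill -USR1 <pid>` once the job hangs should print the file and line each Python thread is sitting on. `format_all_thread_stacks()` is the in-process variant, along the lines of the linked Stack Overflow answer.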