Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

TypeError: unwrap_if_dead() incompatible args in NeVA model #660

Closed: tfogal closed this issue 3 months ago

tfogal commented 3 months ago

🐛 Bug

First install the tfogal/thunder-nemo branch from https://github.com/tfogal/NeMo. Then run:

HYDRA_FULL_ERROR=1 \
THUNDER_ANNOTATE_TRACES=1 \
NEMO_THUNDER_NEVA=1 \
python3 ./examples/multimodal/multimodal_llm/neva/neva_pretrain.py \
  trainer.precision=16 model.megatron_amp_O2=False \
  trainer.num_nodes=1 trainer.devices=1 \
  trainer.val_check_interval=10 trainer.limit_val_batches=5 trainer.log_every_n_steps=1 \
  ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 \
  model.micro_batch_size=2 model.global_batch_size=4 \
  model.tensor_model_parallel_size=1 model.pipeline_model_parallel_size=1 \
  exp_manager.create_checkpoint_callback=False \
  model.data.data_path=./data/multimodal/tiny-neva/dummy.json \
  model.data.image_folder=./data/multimodal/tiny-neva/images \
  model.tokenizer.library=sentencepiece \
  model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model \
  model.num_layers=2 model.hidden_size=5120 model.ffn_hidden_size=13824 \
  model.num_attention_heads=40 model.normalization=rmsnorm \
  model.data.num_workers=0 model.data.conv_template=llama_2 \
  model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14 \
  model.mm_cfg.llm.from_pretrained=null model.use_flash_attention=false \
  exp_manager.exp_dir=./foo-neva-train

After ~15–20 seconds this produces:

Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]Error executing job with overrides: ['trainer.precision=16', 'model.megatron_amp_O2=False', 'trainer.num_nodes=1', 'trainer.devices=1', 'trainer.val_check_interval=10', 'trainer.limit_val_batches=5', 'trainer.log_every_n_steps=1', '++exp_manager.max_time_per_run=00:00:03:00', 'trainer.max_steps=20', 'model.micro_batch_size=2', 'model.global_batch_size=4', 'model.tensor_model_parallel_size=1', 'model.pipeline_model_parallel_size=1', 'exp_manager.create_checkpoint_callback=False', 'model.data.data_path=./data/multimodal/tiny-neva/dummy.json', 'model.data.image_folder=./data/multimodal/tiny-neva/images', 'model.tokenizer.library=sentencepiece', 'model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model', 'model.num_layers=2', 'model.hidden_size=5120', 'model.ffn_hidden_size=13824', 'model.num_attention_heads=40', 'model.normalization=rmsnorm', 'model.data.num_workers=0', 'model.data.conv_template=llama_2', 'model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14', 'model.mm_cfg.llm.from_pretrained=null', 'model.use_flash_attention=false', 'exp_manager.exp_dir=./foo-neva-train']
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 51, in <module>
[rank0]:     main()
[rank0]:   File "/home/tfogal/dev/nemo/nemo/core/config/hydra_runner.py", line 129, in wrapper
[rank0]:     _run_hydra(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
[rank0]:     _run_app(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
[rank0]:     run_and_report(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
[rank0]:     raise ex
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
[rank0]:     return func()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
[rank0]:     lambda: hydra.run(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
[rank0]:     _ = ret.return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
[rank0]:     raise self._return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
[rank0]:     ret.return_value = task_function(task_cfg)
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 45, in main
[rank0]:     trainer.fit(model)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
[rank0]:     self._run_sanity_check()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
[rank0]:     val_loop.run()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
[rank0]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
[rank0]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/parallel/distributed.py", line 1636, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 892, in validation_step
[rank0]:     return MegatronGPTModel.validation_step(self, dataloader_iter)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1370, in validation_step
[rank0]:     loss = self.fwd_bwd_step(dataloader_iter, True, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 658, in fwd_bwd_step
[rank0]:     return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 684, in fwd_bwd_step
[rank0]:     losses_reduced_per_micro_batch = fwd_bwd_function(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 381, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 206, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 827, in fwd_output_and_loss_func
[rank0]:     output_tensor = model(**forward_args)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
[rank0]:     res = self._forward_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 676, in fn_
[rank0]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 224, in cache_info_wrapper
[rank0]:     res = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
[rank0]:     jit_results: TraceResults = interpreter(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 212, in _general_frontend
[rank0]:     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1657, in thunder_general_jit
[rank0]:     result = jfn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6696, in fn_
[rank0]:     raise e
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6664, in fn_2
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 461, in forward
[rank0]:     result = GPTModel.forward(self, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py", line 280, in forward
[rank0]:     lm_output = self.language_model(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 764, in forward
[rank0]:     encoder_input = self.embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 348, in forward
[rank0]:     words_embeddings = self.word_embeddings(input_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1566, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/nn/modules/module.py", line 1575, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 162, in forward
[rank0]:     words_embeddings = super().forward(input_ids, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py", line 245, in forward
[rank0]:     output = reduce_from_tensor_model_parallel_region(output_parallel)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/mappings.py", line 446, in reduce_from_tensor_model_parallel_region
[rank0]:     return _ReduceFromModelParallelRegion.apply(input_)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/autograd/function.py", line 573, in apply
[rank0]:     args = _functorch.utils.unwrap_dead_wrappers(args)
[rank0]:   File "/home/tfogal/dev/pytorch/torch/_functorch/utils.py", line 33, in unwrap_dead_wrappers
[rank0]:     result = tuple(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 2729, in <lambda>
[rank0]:     wv = _interpret_call(lambda iterator: iterator.__next__(), iterator)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6189, in _call_dispatch
[rank0]:     opaque_result: Any = fn(*args_, **kwargs_)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 5895, in thunder_interpreter_generator
[rank0]:     raise e
[rank0]:   File "/home/tfogal/dev/pytorch/torch/_functorch/utils.py", line 34, in <genexpr>
[rank0]:     unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6189, in _call_dispatch
[rank0]:     opaque_result: Any = fn(*args_, **kwargs_)
[rank0]: TypeError: unwrap_if_dead(): incompatible function arguments. The following argument types are supported:
[rank0]:     1. (arg0: torch.Tensor) -> torch.Tensor

[rank0]: Invoked with: t66

Environment

$ for prj in nvfuser torch lightning ; do python3 -m pip freeze | grep -i ${prj} ; done
# Editable Git install with no remote (nvfuser==0.2.6+git9c5f006)
-e /opt/pytorch/nvfuser
-e /opt/pytorch/nvfuser
open-clip-torch==2.24.0
pytorch-lightning==2.3.0
-e git+https://github.com/pytorch/pytorch.git@bd72e28314d8d63bb347becb8309f5ac7761c6b5#egg=torch
torchdiffeq==0.2.4
torchmetrics==1.4.0.post0
torchsde==0.2.6
torchvision @ git+https://github.com/pytorch/vision.git@bf01bab6125c5f1152e4f336b470399e52a8559d
-e [snip, an internal nvidia package with a name matching the greps]
-e git+ssh://git@github.com/tfogal/lightning.git@8df5db52ead1804f9021bb07caa2d4a7a6ab03a1#egg=lightning
lightning-cloud==0.5.69
-e git+ssh://git@github.com/Lightning-AI/lightning-thunder.git@754af861e6ac6de4174147ee7a2250b5286229ec#egg=lightning_thunder
lightning-utilities==0.11.2
pytorch-lightning==2.3.0
$ nvidia-smi | grep -i CUDA
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |

Additional context

unwrap_if_dead appears to be part of functorch.
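
For reference, the guard that raises here (per the torch/_functorch/utils.py frames in the traceback) is roughly the following, paraphrased, so the exact body may differ between PyTorch revisions:

import torch
# unwrap_if_dead is a pybind11 binding; in the PyTorch revision from the
# traceback it is imported from torch._C._functorch.
from torch._C._functorch import unwrap_if_dead

def unwrap_dead_wrappers(args):
    # Under thunder's interpreter the isinstance check is satisfied by the
    # TensorProxy standing in for the real tensor (the "t66" in the error),
    # but the opaque C++ binding only accepts a real torch.Tensor, hence
    # the TypeError above.
    return tuple(
        unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    )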

cc @tfogal

t-vi commented 3 months ago

Staring down the stack trace in the middle of the night: this seems to be Megatron's tensor parallelism applied to embeddings, and the question is what our expectation is. Up to now, I think we have tried to run models without their own parallelization and then set up Thunder's. We might also try running through with tensor parallelism enabled, but then maybe the more natural thing to divert is reduce_from_tensor_model_parallel_region or the layer class's forward method (a rough sketch of that diversion follows the quoted trace below).

[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 162, in forward
[rank0]:     words_embeddings = super().forward(input_ids, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py", line 245, in forward
[rank0]:     output = reduce_from_tensor_model_parallel_region(output_parallel)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/mappings.py", line 446, in reduce_from_tensor_model_parallel_region
[rank0]:     return _ReduceFromModelParallelRegion.apply(input_)
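
A minimal sketch of that diversion, assuming a single-GPU run where tensor_model_parallel_size=1 makes the all-reduce a no-op; the names come from the trace above, and the patch target may need adjusting depending on how layers.py imports the symbol:

import megatron.core.tensor_parallel.layers as tp_layers

def _reduce_passthrough(input_):
    # Hypothetical workaround, not a tested fix: with a tensor-model-parallel
    # world size of 1 the all-reduce is a no-op, so returning the input keeps
    # the autograd.Function (and its unwrap_dead_wrappers call) out of the
    # jitted region entirely.
    return input_

# layers.py calls the name it imported from mappings.py, so rebind it in
# that module's namespace before thunder.jit is applied to the model.
tp_layers.reduce_from_tensor_model_parallel_region = _reduce_passthrough
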
t-vi commented 3 months ago

As @crcrpar points out, further down the stack this is called from autograd.Function.apply:

[rank0]:   File "/home/tfogal/dev/pytorch/torch/autograd/function.py", line 573, in apply
[rank0]:     args = _functorch.utils.unwrap_dead_wrappers(args)

so we might want to allow tracing through autograd.Function.apply and forward (hoping to handle the backward with standard autodiff / registrations of our own). Or we might treat forward and backward as black-box units and include them in the trace as symbols, which would stay closer to the user code. Or we could try to make it our own thing with the thunder derivative transform mechanism.
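
For context on the black-box option: paraphrased from megatron.core.tensor_parallel.mappings (exact code may differ across megatron-core versions), the autograd.Function in question is roughly:

import torch

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        # _reduce all-reduces input_ across the tensor-model-parallel group.
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # The backward of the sum all-reduce is the identity in this scheme.
        return grad_output

Since the backward is just the identity, registering our own derivative for a symbol wrapping the forward looks cheap.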