Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Interpreter error in pure-thunder NeVA path: non-`WrappedValue` in args #891

Closed tfogal closed 1 month ago

tfogal commented 1 month ago

🚀 Model / language coverage

[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1226, in default_opcode_interpreter
[rank0]:     return handler(inst, **interpreter_state)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 3675, in _call_function_ex_handler
[rank0]:     return check_and_append(stack, _interpret_call(func, *args, **kwargs))
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6315, in _interpret_call
[rank0]:     rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6374, in _call_dispatch
[rank0]:     return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6315, in _interpret_call
[rank0]:     rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6533, in _call_dispatch
[rank0]:     return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6735, in _setup_frame_and_run_python_function
[rank0]:     raise e.with_traceback(tb)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6372, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6730, in _setup_frame_and_run_python_function
[rank0]:     res, status = _run_frame(frame, compilectx, runtimectx)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6779, in _run_frame
[rank0]:     interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 412, in interpret
[rank0]:     return self._opcode_interpreter(inst, **interpreter_state)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1226, in default_opcode_interpreter
[rank0]:     return handler(inst, **interpreter_state)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 3675, in _call_function_ex_handler
[rank0]:     return check_and_append(stack, _interpret_call(func, *args, **kwargs))
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6315, in _interpret_call
[rank0]:     rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6476, in _call_dispatch
[rank0]:     res = lookaside_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 930, in _general_jit_torch_autograd_function_apply_lookaside
[rank0]:     return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6315, in _interpret_call
[rank0]:     rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6354, in _call_dispatch
[rank0]:     assert all(isinstance(a, WrappedValue) for a in args)
[rank0]: AssertionError

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 120, in <module>
[rank0]:     main()
[rank0]:   File "/home/tfogal/dev/nemo/nemo/core/config/hydra_runner.py", line 129, in wrapper
[rank0]:     _run_hydra(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
[rank0]:     _run_app(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
[rank0]:     run_and_report(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
[rank0]:     raise ex
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
[rank0]:     return func()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
[rank0]:     lambda: hydra.run(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
[rank0]:     _ = ret.return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
[rank0]:     raise self._return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
[rank0]:     ret.return_value = task_function(task_cfg)
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 111, in main
[rank0]:     trainer.fit(model)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
[rank0]:     self._run_sanity_check()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
[rank0]:     val_loop.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
[rank0]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
[rank0]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1640, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1456, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 897, in validation_step
[rank0]:     return MegatronGPTModel.validation_step(self, dataloader_iter)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1370, in validation_step
[rank0]:     loss = self.fwd_bwd_step(dataloader_iter, True, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 665, in fwd_bwd_step
[rank0]:     return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 684, in fwd_bwd_step
[rank0]:     losses_reduced_per_micro_batch = fwd_bwd_function(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 395, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 219, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 832, in fwd_output_and_loss_func
[rank0]:     output_tensor = model(**forward_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
[rank0]:     res = self._forward_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 696, in fn_
[rank0]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 223, in cache_info_wrapper
[rank0]:     res = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 494, in get_computation_and_inputs
[rank0]:     jit_results: TraceResults = interpreter(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 211, in _general_frontend
[rank0]:     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
[rank0]:     result = jfn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 7072, in fn_
[rank0]:     raise InterpreterError(msg) from e
[rank0]: thunder.core.interpreter.InterpreterError: Encountered exception AssertionError:  while tracing NevaModel(
[rank0]:   (language_model): TransformerLanguageModel(
[rank0]:     (embedding): Embedding(
[rank0]:       (word_embeddings): VocabParallelEmbedding(
[rank0]:         (vision_encoder): CLIPVisionModel(
[rank0]:           (vision_model): CLIPVisionTransformer(
[rank0]:             (embeddings): CLIPVisionEmbeddings(
[rank0]:               (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
[rank0]:               (position_embedding): Embedding(257, 1024)
[rank0]:             )
[rank0]:             (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:             (encoder): CLIPEncoder(
[rank0]:               (layers): ModuleList(
[rank0]:                 (0-23): 24 x CLIPEncoderLayer(
[rank0]:                   (self_attn): CLIPAttention(
[rank0]:                     (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                     (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                     (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                     (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                   )
[rank0]:                   (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:                   (mlp): CLIPMLP(
[rank0]:                     (activation_fn): QuickGELUActivation()
[rank0]:                     (fc1): Linear(in_features=1024, out_features=4096, bias=True)
[rank0]:                     (fc2): Linear(in_features=4096, out_features=1024, bias=True)
[rank0]:                   )
[rank0]:                   (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:                 )
[rank0]:               )
[rank0]:             )
[rank0]:             (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:           )
[rank0]:         )
[rank0]:         (adapter_layer): ModuleDict(
[rank0]:           (mm_projector_adapter): MultimodalProjectorAdapter(
[rank0]:             (mm_projector): Linear(in_features=1024, out_features=5120, bias=True)
[rank0]:           )
[rank0]:         )
[rank0]:       )
[rank0]:       (embedding_dropout): Dropout(p=0.0, inplace=False)
[rank0]:     )
[rank0]:     (rotary_pos_emb): RotaryEmbedding()
[rank0]:     (encoder): ParallelTransformer(
[rank0]:       (layers): ModuleList(
[rank0]:         (0-1): 2 x ParallelTransformerLayer(
[rank0]:           (input_layernorm): MixedFusedRMSNorm(torch.Size([5120]), eps=1e-05, elementwise_affine=True)
[rank0]:           (self_attention): ParallelAttention(
[rank0]:             (query_key_value): ColumnParallelLinear()
[rank0]:             (core_attention): CoreAttention(
[rank0]:               (scale_mask_softmax): MatchedScaleMaskSoftmax()
[rank0]:               (attention_dropout): Dropout(p=0.0, inplace=False)
[rank0]:             )
[rank0]:             (dense): RowParallelLinear()
[rank0]:           )
[rank0]:           (post_attention_layernorm): MixedFusedRMSNorm(torch.Size([5120]), eps=1e-05, elementwise_affine=True)
[rank0]:           (mlp): ParallelMLP(
[rank0]:             (dense_h_to_4h): ColumnParallelLinear()
[rank0]:             (dense_4h_to_h): RowParallelLinear()
[rank0]:           )
[rank0]:         )
[rank0]:       )
[rank0]:       (final_layernorm): MixedFusedRMSNorm(torch.Size([5120]), eps=1e-05, elementwise_affine=True)
[rank0]:     )
[rank0]:     (output_layer): ColumnParallelLinear()
[rank0]:   )
[rank0]:   (embedding): Embedding(
[rank0]:     (word_embeddings): VocabParallelEmbedding(
[rank0]:       (vision_encoder): CLIPVisionModel(
[rank0]:         (vision_model): CLIPVisionTransformer(
[rank0]:           (embeddings): CLIPVisionEmbeddings(
[rank0]:             (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
[rank0]:             (position_embedding): Embedding(257, 1024)
[rank0]:           )
[rank0]:           (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:           (encoder): CLIPEncoder(
[rank0]:             (layers): ModuleList(
[rank0]:               (0-23): 24 x CLIPEncoderLayer(
[rank0]:                 (self_attn): CLIPAttention(
[rank0]:                   (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                   (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                   (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                   (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
[rank0]:                 )
[rank0]:                 (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:                 (mlp): CLIPMLP(
[rank0]:                   (activation_fn): QuickGELUActivation()
[rank0]:                   (fc1): Linear(in_features=1024, out_features=4096, bias=True)
[rank0]:                   (fc2): Linear(in_features=4096, out_features=1024, bias=True)
[rank0]:                 )
[rank0]:                 (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:               )
[rank0]:             )
[rank0]:           )
[rank0]:           (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
[rank0]:         )
[rank0]:       )
[rank0]:       (adapter_layer): ModuleDict(
[rank0]:         (mm_projector_adapter): MultimodalProjectorAdapter(
[rank0]:           (mm_projector): Linear(in_features=1024, out_features=5120, bias=True)
[rank0]:         )
[rank0]:       )
[rank0]:     )
[rank0]:     (embedding_dropout): Dropout(p=0.0, inplace=False)
[rank0]:   )
[rank0]: ):

Full log of the failing run

Pitch

This comes up when trying to support NeVA on the pure-thunder path (i.e. no dynamo frontend).

Alternatives / Potential work-arounds

We could just use the dynamo frontend, for now.

Minimal Repro

Still working on this...

cc @tfogal

t-vi commented 1 month ago

The key bits in the traceback are

[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 227, in mixed_dtype_fused_rms_norm_affine
[rank0]:     return FusedRMSNormAffineMixedDtypesFunction.apply(*args)

and

[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 930, in _general_jit_torch_autograd_function_apply_lookaside
[rank0]:     return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)

The problem is that we should not use tree map here. Instead, we should just pass the list and the dictionary as they are, so that we do not descend into the tuple. I'll send a PR.

https://github.com/Lightning-AI/lightning-thunder/blob/d202ba3d868642e0b4d61cd1bf93794350d5a663/thunder/core/jit_ext.py#L928-L930
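To illustrate the failure mode described above, here is a minimal, self-contained sketch. It does not use Thunder's actual code: the `WrappedValue` class, the simplified `tree_map`, and the helpers `wrap_leaves`, `wrap_top_level`, and `all_wrapped` are all hypothetical stand-ins, and the argument values are made up. The assumption is that wrapping every leaf of `(args, kwargs)` turns a tuple argument into a plain tuple of wrapped elements, which then trips a check like the `assert all(isinstance(a, WrappedValue) for a in args)` seen in the traceback.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class WrappedValue:
    """Hypothetical stand-in for the interpreter's wrapper type."""
    value: Any


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Simplified tree map (an assumption): recurses into tuples, lists and dicts."""
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, x) for x in tree)
    if isinstance(tree, list):
        return [tree_map(fn, x) for x in tree]
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


def wrap_leaves(args, kwargs):
    # Wraps every leaf, so a tuple argument becomes a plain tuple of wrappers.
    return tree_map(WrappedValue, (args, kwargs))


def wrap_top_level(args, kwargs):
    # Wraps only the direct elements of the list and the direct values of the dict.
    return [WrappedValue(a) for a in args], {k: WrappedValue(v) for k, v in kwargs.items()}


def all_wrapped(args):
    # Mirrors the shape of the assertion in the traceback.
    return all(isinstance(a, WrappedValue) for a in args)


# Made-up arguments; the third one is a tuple, like a normalized shape.
args = ["input", "weight", (5120,), 1e-5]
kwargs = {}

deep_args, _ = wrap_leaves(args, kwargs)
flat_args, _ = wrap_top_level(args, kwargs)

print(all_wrapped(deep_args))  # False: the tuple itself is not a WrappedValue
print(all_wrapped(flat_args))  # True: the tuple is wrapped as a single value
```

In this sketch, wrapping only the top level keeps each positional argument as a single wrapped object, whatever its internal structure.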