Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0
1.2k stars 80 forks source link

Dtype mismatch in cat: bfloat16 and float16 #812

Closed tfogal closed 3 months ago

tfogal commented 3 months ago

🚀 Model / language coverage

First, I applied this diff to thunder:

diff --git a/thunder/core/utils.py b/thunder/core/utils.py
index 271dcdf3..346f264e 100644
--- a/thunder/core/utils.py
+++ b/thunder/core/utils.py
@@ -237,6 +237,10 @@ def check_same_dtype(*args):
             if dtype is None:
                 dtype = typ

+            if not are_same_dtypes(dtype, typ):
+                import traceback
+                print(f"mismatched types: {dtype}, {typ}")
+                traceback.print_stack()
             check(
                 are_same_dtypes(dtype, typ),
                 lambda: f"Expected dtype {dtype} but found {typ}!",

The diff was necessary to get the beginning of the output below, which conveys that a cat operator is what is at fault:

  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 812, in cat
    return clang.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/clang/__init__.py", line 1289, in cat
    return prims.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 272, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 2983, in cat_meta
    utils.check_same_dtype(*tensors)
  File "/home/tfogal/dev/thunder/thunder/core/utils.py", line 243, in check_same_dtype
    traceback.print_stack()
mismatched types: thunder.dtypes.bfloat16, thunder.dtypes.float16
Error executing job with overrides: ['trainer.precision=16', 'model.megatron_amp_O2=False', 'trainer.num_nodes=1', 'trainer.devices=1', 'trainer.val_check_interval=10', 'trainer.limit_val_batches=5', 'trainer.log_every_n_steps=1', '++exp_manager.max_time_per_run=00:00:03:00', 'trainer.max_steps=20', 'model.micro_batch_size=2', 'model.global_batch_size=4', 'model.tensor_model_parallel_size=1', 'model.pipeline_model_parallel_size=1', 'exp_manager.create_checkpoint_callback=False', 'model.data.data_path=./data/multimodal/tiny-neva/dummy.json', 'model.data.image_folder=./data/multimodal/tiny-neva/images', 'model.tokenizer.library=sentencepiece', 'model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model', 'model.num_layers=2', 'model.hidden_size=5120', 'model.ffn_hidden_size=13824', 'model.num_attention_heads=40', 'model.normalization=rmsnorm', 'model.data.num_workers=0', 'model.data.conv_template=llama_2', 'model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14', 'model.mm_cfg.llm.from_pretrained=null', 'model.use_flash_attention=false', 'exp_manager.exp_dir=./foo-neva-train']
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 51, in <module>
[rank0]:     main()
[rank0]:   File "/home/tfogal/dev/nemo/nemo/core/config/hydra_runner.py", line 129, in wrapper
[rank0]:     _run_hydra(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
[rank0]:     _run_app(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
[rank0]:     run_and_report(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
[rank0]:     raise ex
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
[rank0]:     return func()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
[rank0]:     lambda: hydra.run(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
[rank0]:     _ = ret.return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
[rank0]:     raise self._return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
[rank0]:     ret.return_value = task_function(task_cfg)
[rank0]:   File "/home/tfogal/dev/nemo/./examples/multimodal/multimodal_llm/neva/neva_pretrain.py", line 45, in main
[rank0]:     trainer.fit(model)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
[rank0]:     self._run_sanity_check()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
[rank0]:     val_loop.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
[rank0]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
[rank0]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1640, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1456, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 897, in validation_step
[rank0]:     return MegatronGPTModel.validation_step(self, dataloader_iter)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1370, in validation_step
[rank0]:     loss = self.fwd_bwd_step(dataloader_iter, True, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 665, in fwd_bwd_step
[rank0]:     return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 684, in fwd_bwd_step
[rank0]:     losses_reduced_per_micro_batch = fwd_bwd_function(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 395, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 219, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 832, in fwd_output_and_loss_func
[rank0]:     output_tensor = model(**forward_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
[rank0]:     res = self._forward_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 683, in fn_
[rank0]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 225, in cache_info_wrapper
[rank0]:     res = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
[rank0]:     jit_results: TraceResults = interpreter(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 213, in _general_frontend
[rank0]:     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
[rank0]:     result = jfn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6769, in fn_
[rank0]:     raise e
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6737, in fn_2
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 470, in forward
[rank0]:     result = GPTModel.forward(self, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py", line 280, in forward
[rank0]:     lm_output = self.language_model(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 764, in forward
[rank0]:     encoder_input = self.embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 348, in forward
[rank0]:     words_embeddings = self.word_embeddings(input_ids)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 155, in forward
[rank0]:     return self.replace_media_embeddings(input_ids, words_embeddings, media)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 195, in replace_media_embeddings
[rank0]:     media_features = self.encode_vision_x(media)  # b T F S(eq) H(idden)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 176, in encode_vision_x
[rank0]:     vision_x = self.vision_encoder(vision_x, output_hidden_states=True)
[rank0]: RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float16!

Full log of the run

Instructions on how to run NeVA are in #343.

Pitch

This is for the NeVA model (#343).

Alternatives / Potential work-arounds

It seems like our cat checks are too stringent, in that torch allows mismatched dtypes here:

>>> a = torch.randn((5,3), dtype=torch.bfloat16)
>>> b = torch.randn((2,3), dtype=torch.float16)
>>> c = torch.cat((a,b), dim=0)
>>> c
tensor([[ 1.7734,  0.4414, -0.3086],
        [-0.4453, -2.2969, -0.2129],
        [-0.6680, -1.3984, -0.0649],
        [ 0.0242, -0.6875,  0.4277],
        [-0.9141,  0.6367,  0.3828],
        [ 1.0635, -0.4417, -0.6030],
        [ 0.5215, -0.6226,  0.9912]])

I suppose torch semantics are to cast each type to the first type?

Note this is very similar to #750. It seems that in #750 the failure merely surfaced in cat even though the underlying error occurred earlier, whereas here we are hitting the same cat dtype check through a different code path.

Minimal Repro

$ cat cat-dtype.py
import torch
import thunder

def foo():
  x = torch.randn((5,3), dtype=torch.bfloat16)
  y = torch.randn((2,3), dtype=torch.float16)
  z = torch.cat((x,y), dim=0)
  return z

foo()
thfoo = thunder.jit(foo)
thfoo()
$ python3 cat-dtype.py
Traceback (most recent call last):
  File "/tmp/cat-dtype.py", line 12, in <module>
    thfoo()
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 683, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6769, in fn_
    raise e
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6737, in fn_2
    return fn(*args, **kwargs)
  File "/tmp/cat-dtype.py", line 7, in foo
    z = torch.cat((x,y), dim=0)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1272, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 704, in wrapper
    return fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 276, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 812, in cat
    return clang.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/clang/__init__.py", line 1289, in cat
    return prims.cat(tensors, dim)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 272, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/prims.py", line 2983, in cat_meta
    utils.check_same_dtype(*tensors)
  File "/home/tfogal/dev/thunder/thunder/core/utils.py", line 240, in check_same_dtype
    check(
  File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: Expected dtype thunder.dtypes.bfloat16 but found thunder.dtypes.float16!

cc @tfogal

t-vi commented 3 months ago

> I suppose torch semantics are to cast each type to the first type?

No, I think they use upcasting (in the above, bf16 and fp16 give fp32).