Closed tfogal closed 3 months ago
Staring down the stack trace in the middle of the night, this seems to be part of Megatron's tensor parallelism applied to embeddings, and the question is what our expectation is. Up to now, I think we have tried to run models without their own parallelization and then set up Thunder's. We might also try running through it with tensor parallelism enabled, but then maybe the more natural thing to divert is reduce_from_tensor_model_parallel_region or the layer class's forward method.
[rank0]: File "/home/tfogal/dev/nemo/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py", line 162, in forward
[rank0]: words_embeddings = super().forward(input_ids, **kwargs)
[rank0]: File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
[rank0]: return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]: File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py", line 245, in forward
[rank0]: output = reduce_from_tensor_model_parallel_region(output_parallel)
[rank0]: File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/mappings.py", line 446, in reduce_from_tensor_model_parallel_region
[rank0]: return _ReduceFromModelParallelRegion.apply(input_)
As @crcrpar points out, further down the stack, this is called by autograd.Function:
[rank0]: File "/home/tfogal/dev/pytorch/torch/autograd/function.py", line 573, in apply
[rank0]: args = _functorch.utils.unwrap_dead_wrappers(args)
so we might want to allow tracing through autograd.Function.apply and forward (and hope to do the backward with standard autodiff / registrations of our own). Or we might treat forward and backward as black-box units and include them in the trace as symbols, which would stay closer to the user code. Or we could try to make it our own thing with the Thunder derivative transform mechanism.
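To make the second option concrete, here is a toy sketch in plain Python of recording a user-defined forward/backward pair as one opaque trace entry and replaying the backward in reverse order. It is only an illustration: `Trace`, `Square`, and `FakeReduce` are hypothetical names, none of this is Thunder's or PyTorch's API, and `FakeReduce` just adds a constant as a stand-in for a collective.

```python
class Square:
    """Hypothetical user function with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx["x"] = x  # save the input for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        return 2.0 * ctx["x"] * grad_out


class FakeReduce:
    """Stand-in for a collective: forward adds a 'peer' value, backward is identity.

    PEER_VALUE is an assumption for illustration; no communication happens here.
    """

    PEER_VALUE = 10.0

    @staticmethod
    def forward(ctx, x):
        return x + FakeReduce.PEER_VALUE

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


class Trace:
    """Toy trace that keeps each forward/backward pair as a black-box entry."""

    def __init__(self):
        self.entries = []  # list of (function class, saved ctx) in call order

    def apply(self, fn_cls, x):
        ctx = {}
        out = fn_cls.forward(ctx, x)  # run the user's forward untouched
        self.entries.append((fn_cls, ctx))
        return out

    def backward(self, grad_out=1.0):
        # Replay the user-provided backwards in reverse; nothing is differentiated by us.
        for fn_cls, ctx in reversed(self.entries):
            grad_out = fn_cls.backward(ctx, grad_out)
        return grad_out

    def names(self):
        return [fn_cls.__name__ for fn_cls, _ in self.entries]


trace = Trace()
y = trace.apply(FakeReduce, trace.apply(Square, 3.0))
grad_x = trace.backward(1.0)
print(trace.names(), y, grad_x)
```

The trade-off in this shape is that the trace never looks inside `forward` or `backward`, so it stays faithful to the user code but gives a compiler nothing to fuse or rewrite within those units.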
🐛 Bug
First install the tfogal/thunder-nemo branch from https://github.com/tfogal/NeMo. Then run:
After ~15--20 seconds this produces:
Environment
Additional context
unwrap_if_dead appears to be part of functorch.cc @tfogal