tfogal opened this issue 4 months ago
At first glance, I would probably try making a lookaside for `make_viewless_tensor` that is a no-op. I don't think we want to model `_base`, and I'm even less keen on trying to force it to be `None` in likely unforeseen ways that mess with the internals of PyTorch. That said, a nicer error message around `_base` would also be good.
> At first glance, I would probably try making a lookaside for `make_viewless_tensor` that is a no-op.
I agree with this. I'm not a fan of the method in question.

Open questions:

- What does `make_viewless_tensor` do? Avoiding garbage collection?
- Should we make `make_viewless_tensor` from megatron a no-op?

Marking high priority as it turns out the 'correct' configuration of NeVA actually runs into this (earlier we avoided this problem by running NeVA in a different mode).
The developer who wrote this is out, but I had a great chat with @jaredcasper who was wicked helpful:

- The viewless tensor exists to support `deallocate_output_tensor`, which enables megatron to delete the tensor's data w/o deleting the tensor itself.
- We could turn off `deallocate_pipeline_output`, which would mean it doesn't even use this optimization.
- Another option is the `foo.new_tensor(foo.data_ptr())` trick.
- The easiest is probably to just `clone()` the tensor. This will have a negative impact on memory use but should get us running.
Thanks, Jared! Marking triage review to make sure we discuss.
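For concreteness, the `clone()` option might look something like the following as a drop-in replacement; this is a sketch under my assumptions rather than Megatron's actual code, and the function name is illustrative:

```python
import torch

def _make_viewless_tensor_via_clone(inp, requires_grad):
    # clone() copies the storage, so the result is not a view (its ._base is
    # None), at the cost of the extra memory noted above.
    out = inp.clone().detach()
    out.requires_grad_(requires_grad)
    return out
```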
Playing with this, are you sure that `.detach()` is not enough (in recent versions of PyTorch at least)?
```python
import torch
print(torch.__version__)  # 2.4.0+cu124
a = torch.randn(100, 100)
b = a.view(100, 100)
assert b._base is a
c = a.detach()
assert c._base is None
assert c.data_ptr() == a.data_ptr()
```
So I would expect this to be equivalent to the `.data` business for `_kernel_make_viewless_tensor`:
```python
def _kernel_make_viewless_tensor(inp, requires_grad):
    return inp.detach().requires_grad_(requires_grad)
```
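For context, the existing `.data` approach looks roughly like this (a from-memory sketch, not the verbatim Megatron source; the name suffix is mine):

```python
import torch

def _kernel_make_viewless_tensor_data_style(inp, requires_grad):
    # Allocate a throwaway tensor and point its .data at the input's storage,
    # so the result shares memory but is not a view (._base is None).
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad)
    out.data = inp.data
    return out
```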
Because of how we currently fudge `torch.autograd.Function`, the detach will not be good for Thunder, so I would start with this:
```python
@thunder.core.jit_ext.register_general_jit_lookaside(
    megatron.core.utils.make_viewless_tensor
)
@thunder.core.jit_ext.interpreter_needs_wrap
def make_viewless_tensor_lookaside(inp, requires_grad, keep_graph):
    return inp
```
for now given that we will be running this in tracing, i.e. with TensorProxies. @tfogal Could you give this a spin please?
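If that works, wiring it up would just be registering the lookaside (as above) and compiling as usual; a rough usage sketch, where `model` and `batch` are placeholder names for a Megatron module and its inputs:

```python
import thunder

# The decorators above register the lookaside with thunder's jit machinery,
# so importing the module that defines it is all the setup needed.
jitted_model = thunder.jit(model)  # `model`: placeholder Megatron module
output = jitted_model(batch)       # `batch`: placeholder input
```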
@tfogal @t-vi I gave it a go and the WAR seems to work; however, the run still crashes with a type mismatch later on. To me the type mismatch does not seem related to this workaround, but I cannot be 100% certain atm. The next error is:

```
AssertionError: Data types for parameters must match when outside of autocasted region. Found input dtype: thunder.dtypes.float32 and 'layer_norm_weight' dtype: thunder.dtypes.bfloat16
```

and it comes from

```
File "Megatron-LM/megatron/core/transformer/transformer_block.py", line 398, in forward
    hidden_states, context = layer(
```

for the layer defined as a `megatron.core.transformer.transformer_layer.TransformerLayer`.
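For anyone triaging that follow-up error, it is the usual fp32-activations-meet-bf16-parameters mismatch; a plain-PyTorch illustration of the shape of the problem (purely illustrative, not the NeVA code path):

```python
import torch

ln = torch.nn.LayerNorm(8, dtype=torch.bfloat16)  # bf16 parameter, like 'layer_norm_weight'
x = torch.randn(2, 8)                             # fp32 activations reaching the layer
# The assertion above is about exactly this: outside an autocast region,
# thunder expects x.dtype to match ln.weight.dtype, and here it does not.
print(x.dtype, ln.weight.dtype)                   # torch.float32 torch.bfloat16
```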
> Could you give this a spin
>
> I gave it a go and
hah, love when people jump on things while I'm asleep :-)
Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or TomV's earlier thought of creating a no-op lookaside for `make_viewless_tensor`. If it's the latter, would it make sense as a thunder PR?
> The next error is:
I'd agree with you that this is likely unrelated. If you could get an issue filed that would be helpful, but a PR would be more important so that other people could even reach that error.
> Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or TomV's earlier thought of creating a no-op lookaside for `make_viewless_tensor`. If it's the latter, would it make sense as a thunder PR?
@tfogal sorry for the fast write-up! The WAR I tried is the no-op. I'll open an issue with the repro and the latest error.
You can repro the error in this issue by using the instructions in #1044 and commenting out the lookaside WAR
> Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or TomV's earlier thought of creating a no-op lookaside for `make_viewless_tensor`.
> @tfogal sorry for the fast write-up! The WAR I tried is the no-op. I'll open an issue with the repro and the latest error.
No need to apologize! When we're all back from free days, could you send a PR with the no-op lookaside patch you have?
🚀 Model / language coverage
Running Megatron GPT from NeMo, we seem to have issues with this line from Megatron core. Some context from the caller in this particular case might be illuminating.
Pitch
This is coming from a finetuning case of the Megatron GPT network. To set up and run the network, see #344.
Alternatives / Potential work-arounds
Marking triage review for help
Minimal Repro
(help wanted)
cc @apaz-cli @tfogal