tfogal closed this issue 1 month ago.
The key parts of the traceback are
[rank0]: File "/home/tfogal/env/lib/python3.10/site-packages/apex/normalization/fused_layer_norm.py", line 227, in mixed_dtype_fused_rms_norm_affine
[rank0]: return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
and
[rank0]: File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 930, in _general_jit_torch_autograd_function_apply_lookaside
[rank0]: return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)
The problem is that we should not use tree_map here. Instead, we should pass the args list and kwargs dictionary through as they are, so that we do not descend into the tuple. I'll send a PR.
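The following minimal sketch illustrates why descending into the arguments matters. It does not reproduce thunder's lookaside code. `Wrapped`, `tree_map`, `wrap_with_tree_map`, and `wrap_top_level` are hypothetical helpers written for this illustration, and the sketch assumes the lookaside needs to treat each positional or keyword argument as a single value. A pytree-style map rebuilds a nested tuple (for example a shape tuple passed to the autograd function) leaf by leaf. Handling only the top-level list and dictionary keeps that tuple intact.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Wrapped:
    """Hypothetical stand-in for an interpreter's wrapped value."""
    value: Any


def tree_map(fn, tree):
    """Minimal pytree map: recurses into lists, tuples and dicts, applies fn to leaves."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, x) for x in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


def wrap_with_tree_map(args, kwargs):
    # Descends into every container, so a tuple argument is split into wrapped leaves.
    return tree_map(Wrapped, args), tree_map(Wrapped, kwargs)


def wrap_top_level(args, kwargs):
    # Touches only the top-level list and dict; each argument is wrapped as a whole.
    return [Wrapped(a) for a in args], {k: Wrapped(v) for k, v in kwargs.items()}


# Hypothetical arguments: input, weight, a shape tuple, and eps.
args = ["x", "w", (4096,), 1e-5]
print(wrap_with_tree_map(args, {})[0][2])  # (Wrapped(value=4096),)
print(wrap_top_level(args, {})[0][2])      # Wrapped(value=(4096,))
```

With the tree map, the function being interpreted receives a freshly built tuple of wrapped leaves rather than the original tuple argument, which is the kind of mismatch the comment above suggests avoiding.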
🚀 Model / language coverage
Pitch
This comes up when trying to support NeVA on the pure-thunder path (i.e., without the dynamo frontend).
Alternatives / Potential work-arounds
For now, we could just use the dynamo frontend.
Minimal Repro
Still working on this...
cc @tfogal