Closed DouglasOrr closed 3 months ago
This fixes a regression in the auto-conversion of GPT in the `out-of-the-box-fp8-training` notebook.
Current state:
```python
mod = type("Custom", (nn.Module,), {})()
mod.child = type("CustomChild", (nn.Module,), {})()
mod.child.lin = nn.Linear(10, 20)
print(unit_scale(mod))
# Custom(
#   (child): CustomChild(
#     (lin): Linear(in_features=10, out_features=20, bias=True)
#   )
#   (child.lin): trivial_subclass_modules_linear_Linear(in_features=10, out_features=20, bias=True)
# )
```
In `torch_nn_modules_to_user_modules()`, since `setattr(mod, n, newsubmod)` happily sets a non-valid-identifier key `n` such as `"child.lin"` without recursing, we get new child modules rather than replacing the existing ones.
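A minimal, torch-free sketch of the underlying behaviour (a plain Python object stands in for `nn.Module`; the names here are illustrative, not from the patch):

```python
class Module:  # stand-in for nn.Module, to keep the sketch torch-free
    pass

mod = Module()
mod.child = Module()
mod.child.lin = "old"

# setattr does not interpret dots: this adds a literal "child.lin"
# attribute on mod, leaving the nested mod.child.lin untouched
setattr(mod, "child.lin", "new")

print(mod.child.lin)              # -> old  (unchanged)
print(getattr(mod, "child.lin"))  # -> new  (a new sibling, not the nested module)
```

In real `nn.Module` instances the dotted key ends up in `_modules`, which is why the repr above shows both `(child)` and a spurious `(child.lin)` entry.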
This patch:
```python
print(unit_scale(mod))
# Custom(
#   (child): CustomChild(
#     (lin): trivial_subclass_modules_linear_Linear(in_features=10, out_features=20, bias=True)
#   )
# )
```
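The fix amounts to resolving the dotted path to the direct parent before assigning. A hedged sketch of that pattern (illustrative helper, not the exact code from the patch; again torch-free):

```python
class Module:  # torch-free stand-in for nn.Module
    pass

def replace_submodule(root, dotted_name, new_module):
    # Walk all but the last path component, then setattr on the direct
    # parent, so the existing nested child is replaced in place rather
    # than a new dotted-key sibling being added on the root.
    *parents, leaf = dotted_name.split(".")
    parent = root
    for name in parents:
        parent = getattr(parent, name)
    setattr(parent, leaf, new_module)

mod = Module()
mod.child = Module()
mod.child.lin = "old"
replace_submodule(mod, "child.lin", "new")
print(mod.child.lin)  # -> new
```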
(There may still be some issues, e.g. when a module instance is shared by different parents: after conversion the copies could become un-shared.)