graphcore-research / unit-scaling

A library for unit scaling in PyTorch
https://graphcore-research.github.io/unit-scaling/
Apache License 2.0

Fix recursion in torch_nn_modules_to_user_modules() #57

Closed · DouglasOrr closed this 3 months ago

DouglasOrr commented 3 months ago

This fixes a regression in the auto-conversion of GPT in the out-of-the-box-fp8-training notebook.

Current state:

from torch import nn

from unit_scaling.transforms import unit_scale

mod = type("Custom", (nn.Module,), {})()
mod.child = type("CustomChild", (nn.Module,), {})()
mod.child.lin = nn.Linear(10, 20)

print(unit_scale(mod))
# Custom(
#   (child): CustomChild(
#     (lin): Linear(in_features=10, out_features=20, bias=True)
#   )
#   (child.lin): trivial_subclass_modules_linear_Linear(in_features=10, out_features=20, bias=True)
# )

In torch_nn_modules_to_user_modules(), setattr(mod, n, newsubmod) happily registers a key n that is not a valid identifier, such as "child.lin", directly on the root module rather than recursing into child, so we end up adding new child modules instead of replacing the existing ones.
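
For reference, a minimal sketch of the intended behaviour (a hypothetical replace_submodule helper, not the exact patch), assuming nn.Module.get_submodule, which resolves dotted paths and returns the root for an empty path:

from torch import nn

def replace_submodule(root: nn.Module, name: str, new: nn.Module) -> None:
    # Split e.g. "child.lin" into parent path "child" and leaf attribute "lin".
    parent_path, _, leaf = name.rpartition(".")
    # get_submodule("") returns root itself, so top-level names also work.
    parent = root.get_submodule(parent_path)
    # setattr on the *parent* replaces the existing submodule in place.
    setattr(parent, leaf, new)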

With this patch:

print(unit_scale(mod))
# Custom(
#   (child): CustomChild(
#     (lin): trivial_subclass_modules_linear_Linear(in_features=10, out_features=20, bias=True)
#   )
# )

(There may still be some issues, e.g. module instances shared by different parents could become un-shared after conversion; see the sketch below.)
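
To illustrate the sharing caveat, a hypothetical tied-submodule setup (not from the notebook):

from torch import nn

root = type("Custom", (nn.Module,), {})()
root.a = type("ChildA", (nn.Module,), {})()
root.b = type("ChildB", (nn.Module,), {})()
root.a.lin = root.b.lin = nn.Linear(10, 10)
assert root.a.lin is root.b.lin  # one instance, parameters tied

# A per-name conversion pass may replace "a.lin" and "b.lin" with separate
# converted instances (or replace only one of them, since named_modules()
# de-duplicates shared instances), silently breaking the tie.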