microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Usage with torch.compile in PyTorch 2? #60

Open dreavjr opened 1 year ago

dreavjr commented 1 year ago

Is mup compatible with torch.compile() in PyTorch 2? If so, what is the correct usage (e.g. should we apply mup before or after compiling)?

edwardjhu commented 11 months ago

I don't see why it wouldn't be compatible out of the box, but I haven't tested it.

What happens to the coordinate check if you rerun one of our examples after torch.compile()?

tivek commented 7 months ago

Recently, torch.compile() started using FakeTensors for both inputs and weights during compilation. That means temporary FakeTensor weights are created from the original Tensor weights, and the infshape attributes are not copied to these FakeTensor weights.

Consequently, during compilation, MuReadout.forward() and MuReadout.width_mult() trip this assert and the compilation fails.

This unwanted side effect will also affect the ability to, e.g., export mup models to ONNX.

Any advice on how to work around the missing infshapes on FakeTensors going forward?
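
One direction that might be worth trying, sketched here under assumptions and not tested against mup itself: compute the width multiplier once, eagerly, and keep it as a plain Python float on the module, so that forward() never reads a custom attribute from the weight tensor. The class `CachedWidthReadout`, its `base_in_features` argument, and the `cached_width_mult` attribute are all hypothetical names made up for this illustration; this is a toy stand-in for a readout layer, not mup's MuReadout, and the multiplier is assumed to be fan-in divided by base fan-in.

```python
import torch
import torch.nn as nn


class CachedWidthReadout(nn.Linear):
    """Toy readout layer (hypothetical; not mup's MuReadout)."""

    def __init__(self, in_features, out_features, base_in_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Assumption: width multiplier = fan_in / base fan_in.
        # It is stored on the module as a plain float, so it does not
        # depend on any attribute attached to the weight tensor.
        self.cached_width_mult = float(in_features) / float(base_in_features)

    def forward(self, x):
        # Scale the input by the cached multiplier, then apply the linear map.
        return super().forward(x / self.cached_width_mult)


# Eager-mode usage with made-up sizes.
readout = CachedWidthReadout(in_features=512, out_features=10, base_in_features=128)
x = torch.randn(4, 512)
y = readout(x)
```

Because the multiplier lives on the module and not on the weight, wrapping such a layer with torch.compile() should in principle not need infshape on the FakeTensor weights. Whether that holds for a full mup model, where the optimizer and initialization also read infshape, would still need to be checked, for instance with the coordinate check mentioned above.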