dreavjr opened this issue 1 year ago (status: open).
I don't see any reason it wouldn't be compatible out of the box, but I haven't tested it.
What happens to the coordinate check if you rerun one of our examples with the model wrapped in torch.compile()?
Recently, torch.compile() started using FakeTensors for both inputs and weights during compilation. This means temporary FakeTensor weights are created from the original Tensor weights, and the infshape attributes that mup attaches to those weights are not copied over to the FakeTensors.
Consequently, during compilation, MuReadout.forward() and MuReadout.width_mult() trip this assert (the check that the weight has an infshape), and compilation fails.
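To make the failure mode concrete, here is a plain-Python analogy with no torch or mup involved. `InfShape`, `Weight`, `fake_from`, and `ToyReadout` are hypothetical stand-ins: the real weight carries an extra `infshape` attribute, the stand-in for the FakeTensor is rebuilt from the weight's data only, so the attribute is lost and an infshape check like the one described above fails.

```python
class InfShape:
    """Hypothetical stand-in for mup's infinite-shape metadata."""

    def __init__(self, width_mult: float) -> None:
        self._width_mult = width_mult

    def width_mult(self) -> float:
        return self._width_mult


class Weight:
    """Hypothetical stand-in for a parameter tensor."""

    def __init__(self, data) -> None:
        self.data = list(data)


def fake_from(weight: Weight) -> Weight:
    # Mimics building a temporary proxy from the original weight:
    # only the data is carried over, not extra Python attributes.
    return Weight(weight.data)


class ToyReadout:
    """Hypothetical readout that looks up infshape on its weight."""

    def __init__(self, weight: Weight) -> None:
        self.weight = weight

    def width_mult(self) -> float:
        # Analogous to an assert that the weight carries an infshape.
        assert hasattr(self.weight, "infshape"), "weight has no infshape"
        return self.weight.infshape.width_mult()


real = Weight([0.1, 0.2])
real.infshape = InfShape(4.0)
print(ToyReadout(real).width_mult())  # 4.0

try:
    ToyReadout(fake_from(real)).width_mult()
except AssertionError as err:
    print("proxy weight failed:", err)
```

The point of the analogy is only that metadata stored as an ad hoc attribute on a tensor object does not survive when a different object is built from that tensor; how torch.compile() constructs its FakeTensors internally is not modeled here.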
This unwanted side effect will also affect the ability to, e.g., export mup models to ONNX.
Any advice on how to work around the missing infshape attributes on FakeTensors going forward?
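One possible direction, offered only as an untested idea: read the width multiplier once, eagerly, while the weight is still a real tensor, and keep it as a plain float on the module, so the forward pass never looks up infshape on a (possibly fake) weight. The plain-Python sketch below illustrates the idea; `InfShape`, `Weight`, `fake_from`, `CachedReadout`, and `cache_width_mult()` are all hypothetical names, not part of mup or torch, and the forward pass only loosely mirrors a muP readout by dividing its input by `width_mult()`.

```python
class InfShape:
    """Hypothetical stand-in for mup's infinite-shape metadata."""

    def __init__(self, width_mult: float) -> None:
        self._width_mult = width_mult

    def width_mult(self) -> float:
        return self._width_mult


class Weight:
    """Hypothetical stand-in for a parameter tensor."""

    def __init__(self, data) -> None:
        self.data = list(data)


def fake_from(weight: Weight) -> Weight:
    # Mimics a compile-time proxy: data is copied, extra attributes are not.
    return Weight(weight.data)


class CachedReadout:
    """Hypothetical readout that snapshots width_mult as a plain float."""

    def __init__(self, weight: Weight) -> None:
        self.weight = weight
        self._width_mult = None  # filled in eagerly, before compilation

    def cache_width_mult(self) -> None:
        # Must run on the real weight, while it still carries infshape.
        self._width_mult = float(self.weight.infshape.width_mult())

    def width_mult(self) -> float:
        assert self._width_mult is not None, "call cache_width_mult() first"
        return self._width_mult

    def forward(self, x):
        # Scale the input by 1 / width_mult without reading self.weight's attributes.
        return [v / self.width_mult() for v in x]


real = Weight([0.1, 0.2])
real.infshape = InfShape(4.0)

readout = CachedReadout(real)
readout.cache_width_mult()          # eager step, before "compiling"
readout.weight = fake_from(real)    # simulate the weight being swapped for a proxy
print(readout.forward([2.0, 4.0]))  # [0.5, 1.0]
```

The trade-off is that the cached float goes stale if the base shapes are changed after caching. Whether this pattern carries over cleanly to a real MuReadout under torch.compile() (for example as a plain Python float attribute or a registered buffer) is an assumption that would need testing.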
Is mup compatible with torch.compile() in PyTorch 2? If so, what is the correct usage (e.g., should mup be applied before or after compiling)?