Closed: tomschelsen closed this issue 3 months ago
Yes, compiling the model is somehow not compatible. I tried it on the day PyTorch 2.0 was released, and after a bit of trial and error I managed to compile the model; the biggest obstacle was the use of weight norm. However, once the model was compiled, the forward pass produced unexpected results, basically just nonsense outputs. I'm not sure where exactly the problem lies, but I was frustrated and haven't used the compile feature since. I also played around with using JIT to compile some components of the model, which works okay, but can sometimes also cause unexpected behaviour. Overall I decided not to pursue these compilation features further for now, because I have to manage my priorities, and this would require a full rewrite of many modules to even partially compile most of the TTS.
More for information (in case anyone wants to try it) than a bug report, as I think the `torch.compile()` situation isn't fully stabilised yet: I tried this new facility of PyTorch 2 on IMS-Toucan, with the initial goal of benchmarking different compiler backends (the default `"inductor"` versus the more recently released third-party `"hidet"`) and different parameters.

I tried to compile `ToucanTTSInterface`, as it is the highest-level `torch.nn.Module` that I am currently using for inference. I first encountered the following (I edited the backtrace to remove non-relevant paths and call info):
Looking at PyTorch's issues, it seems that currently (PyTorch 2.0.1) `torch.compile()` doesn't always play well with `inference_mode()` (but they are working on it). So I went ahead and blindly replaced all the `inference_mode()` contexts/decorators with `no_grad()` (the point, again, being to benchmark in the end and assess whether performance beats the starting point). As a result I got the following:
And... this goes way beyond my understanding of how all of that works. So if anyone wants to give it a shot ;)