unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Mixed precision support for Torch Models #2484

Open Joelius300 opened 2 months ago

Joelius300 commented 2 months ago

Is your feature request related to a current problem? Please describe.

Currently, the only valid precisions when training a Torch model are 32-true and 64-true (see here). Although it seems to have been possible to load weights for lower-precision models (namely float16) since #2046, I didn't find a way to train one from scratch at these lower precisions. Given limited hardware resources, it would be awesome to speed up training with bf16-mixed, 16-mixed, or any of the other precisions supported by Lightning.
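For illustration, this is roughly how I set the precision today via `pl_trainer_kwargs` (the model and the synthetic series are just placeholders, and the claim that the value has to match the series dtype is my understanding, not something I verified in the source):

```python
import numpy as np
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

# Placeholder series; float32 so that "32-true" is the matching precision.
series = sine_timeseries(length=200).astype(np.float32)

# Works today: "32-true" (for float32 series) or "64-true" (for float64 series).
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=1,
    pl_trainer_kwargs={"precision": "32-true"},
)
model.fit(series)
```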

Describe proposed solution

It would be nice to be able to specify the precision the same way as now, just with more options available (ideally every precision that actually works).
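Something like the following is what I have in mind, reusing the setup from above (hypothetical usage, this combination is currently rejected by darts):

```python
# Hypothetical: accept any precision string that Lightning itself supports.
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=1,
    pl_trainer_kwargs={"precision": "bf16-mixed"},  # or "16-mixed", "16-true", ...
)
model.fit(series)  # today this fails because only "32-true"/"64-true" are allowed
```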

Describe potential alternatives

I have not found an alternative besides not training through darts.
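To spell out what "not training through darts" would look like: a plain Lightning setup where mixed precision is just a `Trainer` flag. The model and data below are made up and only serve to show the flag; whether "16-mixed" is usable depends on having a GPU.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class TinyForecaster(pl.LightningModule):
    """Minimal window-to-window forecaster, only here to demonstrate the precision flag."""

    def __init__(self, input_len: int = 24, output_len: int = 12):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_len, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, output_len),
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# Dummy sliding-window data.
x = torch.randn(256, 24)
y = torch.randn(256, 12)
loader = DataLoader(TensorDataset(x, y), batch_size=32)

# Mixed precision is a plain Trainer argument in Lightning.
# "16-mixed" generally needs a GPU; "bf16-mixed" also works on recent CPUs.
trainer = pl.Trainer(max_epochs=1, precision="bf16-mixed")
trainer.fit(TinyForecaster(), loader)
```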

Additional context

I'm sure there's a reason why only 32- and 64-bit are supported, but I didn't find much in the docs or other issues. The disclaimer that 16 bit is not supported was added in #1651, together with tests, but I was not able to figure out what exactly the issue is with e.g. 16-mixed.

I also found a number of issues mentioning precision, e.g. #2344 and #1987, but none of them seem to still ask for 16-bit support, so I thought I'd open this issue; apologies if I missed something relevant.

#1987 in particular seems to have discussed it, but it was closed once the bug that reset the precision after loading was fixed. There's also the old #860, which was not revisited IIRC, but it mentions that lower precision would break functionality for Scalers.

MichaelVerdegaal commented 4 days ago

Also interested in this