Open Karlheinzniebuhr opened 2 years ago
It doesn't look like you're specifying the solver or the tolerances/step sizes at all. The defaults may not be suitable for your data.
You may like to try Diffrax instead, which is built for JAX. It builds on a lot of the lessons we learnt building torchcde/torchsde/torchdiffeq, and in particular has much better default behaviour: it demands that you make an explicit choice about this kind of thing.
I adapted your time_series_classification example for market data prediction. It seems to be working, but training is exceptionally slow on a P100 GPU, which normally finishes similar tasks in 30 minutes. After 4 hours it had completed the first 2 epochs. Is this normal with CDEs, or did I do something wrong? The training loss is also diverging, but that might be due to the learning rate; I haven't checked that yet. Here is the dataprep function I added, as well as some minor adaptations to the model.
The complete code with corresponding data CSV: time_series_prediction example