remigenet / TKAN

TKAN: Temporal Kolmogorov-Arnold Networks

Slow training #10

Closed. BradleyCornish closed this issue 4 weeks ago.

BradleyCornish commented 1 month ago

I am able to create a model, but training is incredibly slow compared to an LSTM model with a similar number of parameters (800 ms/step vs 7 ms/step). Have others experienced similar issues? I am using an RTX 3090 GPU and TensorFlow 2.9 with Python 3.10.

Thanks.

remigenet commented 1 month ago

Hi, could you share the code you used and the shape of the input data?

remigenet commented 1 month ago

I think it's related to your TensorFlow version and the jit_compile option. In TF 2.17, I have seen something that takes 15 ms per step with an LSTM take around 75 ms per step when jit_compile is not set in compile() (the default being 'auto'), while with jit_compile=False it takes between 1 and 2 seconds per step! In addition, a new version of TKAN will be released soon, written in Keras to support multiple backends (it is already on a second branch for now), and I have observed better performance using "jax[cuda12]" than using TensorFlow (around 20-25% better).
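A minimal sketch of how one might measure the effect of jit_compile on step time, assuming Keras 3 (the version TF 2.16+ ships with, where compile() accepts jit_compile=True, False or "auto"). An LSTM stands in for the TKAN layer because the TKAN constructor arguments are not covered in this thread; build_model and ms_per_step are hypothetical helper names, and the data is random.

```python
import time

import numpy as np
import keras


def build_model(timesteps: int, features: int) -> keras.Model:
    # Stand-in recurrent model (assumption): swap the LSTM for a TKAN layer
    # to reproduce the comparison discussed in this thread.
    return keras.Sequential([
        keras.Input(shape=(timesteps, features)),
        keras.layers.LSTM(32),
        keras.layers.Dense(1),
    ])


def ms_per_step(jit_compile, x, y, batch_size=32, epochs=2):
    model = build_model(x.shape[1], x.shape[2])
    # jit_compile controls XLA compilation of the train step.
    model.compile(optimizer="adam", loss="mse", jit_compile=jit_compile)
    # Warm-up epoch: includes tracing (and XLA compilation when enabled),
    # so it is excluded from the timing below.
    model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
    steps = int(np.ceil(len(x) / batch_size)) * epochs
    start = time.perf_counter()
    model.fit(x, y, batch_size=batch_size, epochs=epochs, verbose=0)
    return 1000 * (time.perf_counter() - start) / steps


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 20, 4)).astype("float32")
y = rng.normal(size=(256, 1)).astype("float32")

timings = {jc: ms_per_step(jc, x, y) for jc in ("auto", False)}
for setting, ms in timings.items():
    print(f"jit_compile={setting!r}: {ms:.1f} ms/step")
```

Absolute numbers depend on the hardware and backend; the useful signal is the relative gap between the two settings on the same machine.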

kunliu916 commented 1 month ago

How do I use "jax[cuda12]"? My training is extremely slow and I can hardly get it to finish.

kunliu916 commented 1 month ago

Do you mean that the TKAN layer will be directly callable from the Keras API? Approximately when will that be available?

remigenet commented 1 month ago

Well, I just released it! Making it in Keras instead of TensorFlow does not mean that it will be accessible directly from the Keras API; it means it is written in Keras so that it supports different backends (torch, jax and tensorflow). It should still work with TensorFlow (I have added training times on a 3080 Ti, compared with GRU and LSTM, to the notebook in the examples section). Keras detects which backend you are using, so if you are using TensorFlow it will use TensorFlow under the hood.

What you should do here is either use a more recent TensorFlow version or use another backend such as JAX. If your code only uses Keras, this is straightforward, but if you have direct TensorFlow calls you will need to change them.

To use JAX as the backend, install the GPU-enabled version with:

    pip install "jax[cuda12]"

and then, at the beginning of your script (before importing keras), do:

    import os
    BACKEND = 'jax'  # You can use any backend here
    os.environ['KERAS_BACKEND'] = BACKEND

since the default backend is TensorFlow! That said, recent TensorFlow versions that use JIT compilation are not far behind. I believe your code is running in eager mode rather than being compiled because of the older version.

remigenet commented 1 month ago

I just realized from the TKAT issue that you seem to be on Windows, so I'm not really sure what is compatible for GPU usage there. Maybe a dual boot with Linux would help for your DL projects!