jdtoscano94 / Learning-Scientific_Machine_Learning_Residual_Based_Attention_PINNs_DeepONets

Physics Informed Machine Learning Tutorials (Pytorch and Jax)

Issue with using your code in Colab because of JAX! #4

Open AmirhosseinnnKhademi opened 1 year ago

AmirhosseinnnKhademi commented 1 year ago

Hi Juan, thanks for your helpful video. I subscribed! However, I have a problem: "from jax.experimental import optimizers" does not work for me; it fails with "cannot import name 'optimizers' from 'jax.experimental'". To get it running I have to switch to the CPU runtime and install an older version with "!pip install jax[cpu]==0.2.27". This is confusing to me, and from searching online it seems other people hit the same thing. Could you please let me know how you used the GPU in your video? On the CPU, training takes forever!
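For reference, a minimal sketch of the version-pinning workaround described above (the pinned version is the one quoted in this comment; it installs a CPU-only JAX in which the legacy module still exists, and the adam call is just an illustrative use of the old API):

```python
# Workaround mentioned above: pin an older, CPU-only JAX so the legacy
# optimizers module is still available. 0.2.27 is the version quoted above.
# In a Colab cell:
#   !pip install "jax[cpu]==0.2.27"

from jax.experimental import optimizers  # present in JAX 0.2.x, removed in later releases

# Example: the legacy optimizer API returns a triple of functions
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
```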

Thank you

jdtoscano94 commented 1 year ago

Hi, thanks for reporting this issue. I used the Google Colab GPU. However, it looks like they changed something in the recent updates, so it no longer works. We only use jax.experimental to get our optimizer. I would recommend finding a way to use a previous version of Colab or of the jax.experimental library. If not, as an alternative, you can try optax, but you may need to make some minor changes downstream too.

To install and import optax you would use the following commands:

    !pip install optax
    import optax

Then you can create an optimizer using the following lines:

    optimizer = optax.adam(optax.exponential_decay(lr0, decay_step, decay_rate))
    opt_state = optimizer.init(params)

To update your parameters you need to compute the gradients of your loss. You can do something like this:

    def loss_fn(params):
        # your loss function here
        return loss

    grad_fn = value_and_grad(loss_fn)
    loss, grads = grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
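A more complete, self-contained sketch of this optax workflow is below. The loss function, data, and hyperparameter values are placeholders for illustration only, not the repository's actual PINN loss:

```python
# Minimal, runnable sketch of replacing jax.experimental.optimizers with optax.
# Hypothetical toy problem: fit y = 2x with a single weight parameter.
import jax
import jax.numpy as jnp
from jax import value_and_grad
import optax

params = {"w": jnp.array(0.0)}
x = jnp.linspace(0.0, 1.0, 32)
y = 2.0 * x

# Placeholder hyperparameters for the exponential-decay learning-rate schedule
lr0, decay_step, decay_rate = 1e-2, 1000, 0.9
optimizer = optax.adam(optax.exponential_decay(lr0, decay_step, decay_rate))
opt_state = optimizer.init(params)

def loss_fn(params):
    # Placeholder loss: mean squared error of the toy model
    pred = params["w"] * x
    return jnp.mean((pred - y) ** 2)

@jax.jit
def step(params, opt_state):
    # Compute loss and gradients, then apply the optax update
    loss, grads = value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for i in range(2000):
    params, opt_state, loss = step(params, opt_state)
```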

Thanks for letting me know and I am sorry I could not help much this time.