e3nn / e3nn-jax

JAX library for E(3)-equivariant neural networks
Apache License 2.0

Optimizer state during e3nn model training #81

Closed: VinaySingh561 closed this issue 1 month ago

VinaySingh561 commented 1 month ago

Hi, I am training an e3nn model in JAX. I have trained it for 100 epochs with a loss function that combines an energy term and a force term, and I currently get a good R² score for the energy. Now I want to increase the weight of the force term in the loss, so I would like to continue training from the parameters of my last trained model. However, I am unsure how to initialize the optimizer state.
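For reference, a combined loss like the one described is often a weighted sum of an energy error and a force error, with forces taken as the negative gradient of the predicted energy with respect to atomic positions. The sketch below assumes that setup; `toy_energy` is a hypothetical pair potential standing in for the actual e3nn model, and `energy_weight` / `force_weight` are hypothetical names, not taken from the original code.

```python
import jax
import jax.numpy as jnp


def toy_energy(params, positions):
    """Hypothetical stand-in for an e3nn energy model: harmonic springs between atom pairs."""
    diff = positions[:, None, :] - positions[None, :, :]   # (n, n, 3) pair displacements
    r2 = jnp.sum(diff**2, axis=-1)                          # (n, n) squared distances
    return 0.5 * params["k"] * jnp.sum(jnp.triu(r2, k=1))  # upper triangle: each pair once


def energy_force_loss(params, positions, e_ref, f_ref, energy_weight=1.0, force_weight=1.0):
    """Weighted sum of energy and force squared errors; forces are -dE/d(positions)."""
    energy, dE_dpos = jax.value_and_grad(toy_energy, argnums=1)(params, positions)
    forces = -dE_dpos
    energy_loss = (energy - e_ref) ** 2
    force_loss = jnp.mean((forces - f_ref) ** 2)
    return energy_weight * energy_loss + force_weight * force_loss


# Two atoms one unit apart, with dummy reference labels.
params = {"k": 1.0}
positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
e_ref, f_ref = 0.0, jnp.zeros((2, 3))

loss_before = energy_force_loss(params, positions, e_ref, f_ref, force_weight=1.0)
loss_after = energy_force_loss(params, positions, e_ref, f_ref, force_weight=10.0)
print(float(loss_before), float(loss_after))
```

Raising `force_weight` changes only the loss, not the model, so the saved `params` stay valid to start from; the open question is what to do with the optimizer state.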

I have already saved the optimizer state and the parameters from the previous run. Should I reuse the saved opt_state, or initialize a new one with opt.init(params)? I have attached my code below.

Thank you for your time and consideration.

```python
import pickle

import optax

# Define a learning rate scheduler
learning_rate_schedule = optax.exponential_decay(
    init_value=1e-2,       # initial learning rate
    transition_steps=900,  # how often to decay
    decay_rate=0.9,        # the decay rate
    staircase=True,        # if True, decay happens at discrete intervals
)

# Initialize the optimizer with the learning rate schedule
opt = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(learning_rate_schedule),
    optax.scale(-1.0),  # multiply by -1.0 to perform gradient descent
)

# Load the parameters and optimizer state saved from the previous run
with open('best_energy_model_0.9978.pkl', 'rb') as f:
    checkpoint = pickle.load(f)
params = checkpoint['params']

# Option 1: reuse the saved optimizer state
opt_state = checkpoint['opt_state']
# Option 2: start from a fresh optimizer state
# opt_state = opt.init(params)
```

ameya98 commented 1 month ago

Hmm, I would reuse the optimizer state, but I don't think there is one correct answer. It probably depends on how much the force weighting has changed since your original training. I think this question is technically out of scope for this project, though.