Hi, I am running the e3nn model in JAX, and I have trained it for 100 epochs with a loss function that includes both energy and force terms. I am currently achieving a good R² score for energy. Now I want to increase the weight of the force loss in training, so I would like to start from the parameters of my last trained model. However, I am unsure how to initialize the optimizer state.
I have already saved the optimizer state and parameters from the previous model. Should I use the opt_state from the trained model or initialize it using opt.init(params)? I have attached the code below for your convenience.
Thank you for your time and consideration.
```python
import pickle

import optax

# Learning rate schedule: exponential decay every 900 steps by a factor of 0.9
learning_rate_schedule = optax.exponential_decay(
    init_value=1e-2,       # initial learning rate
    transition_steps=900,  # how often to decay
    decay_rate=0.9,        # the decay rate
    staircase=True,        # if True, decay happens at discrete intervals
)

# Optimizer: Adam scaled by the schedule; scale by -1.0 to perform gradient descent
opt = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(learning_rate_schedule),
    optax.scale(-1.0),
)

# Load the parameters and optimizer state saved from the previous training run
with open('best_energy_model_0.9978.pkl', 'rb') as f:
    checkpoint = pickle.load(f)
params = checkpoint['params']

# Option 1: resume from the saved optimizer state
opt_state = checkpoint['opt_state']
# Option 2: re-initialize the optimizer for the loaded parameters
# opt_state = opt.init(params)
```
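For context, whichever `opt_state` is chosen, it would typically be consumed in a standard optax training step like the minimal sketch below; `loss_fn` and `batch` are placeholders for the actual combined loss and data, not part of the code above.

```python
import jax
import optax

@jax.jit
def train_step(params, opt_state, batch):
    # `loss_fn(params, batch)` is a placeholder for the combined energy/force loss.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # `opt` is the optax.chain(...) transformation built above.
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```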
Hmm, I would re-use the optimizer state, but I don't think there is one correct answer. It probably depends on how much the force weighting has changed since your original training. I think this question is technically out of the scope of this project though.
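To make the force-weighting point concrete, a weighted energy-and-force loss for a single configuration might look like the sketch below; `model_energy` and `force_weight` are hypothetical names used for illustration, not part of the e3nn-jax API.

```python
import jax
import jax.numpy as jnp

def energy_force_loss(params, positions, true_energy, true_forces, force_weight=25.0):
    # Forces are the negative gradient of the (scalar) energy w.r.t. atomic positions.
    energy_fn = lambda r: model_energy(params, r)  # `model_energy` is a placeholder
    pred_energy, denergy_dpos = jax.value_and_grad(energy_fn)(positions)
    pred_forces = -denergy_dpos

    energy_loss = (pred_energy - true_energy) ** 2
    force_loss = jnp.mean(jnp.sum((pred_forces - true_forces) ** 2, axis=-1))
    return energy_loss + force_weight * force_loss
```

Increasing `force_weight` rescales the gradients relative to those seen during the first 100 epochs, which is presumably what the reply above means by how much the force weighting has changed.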