rs-station / careless

Merge X-ray diffraction data with Wilson's priors, variational inference, and metadata

Optimal number of cycles #84

Closed gyuhyeokcho closed 1 year ago

gyuhyeokcho commented 1 year ago

What is the optimal number of cycles? I noticed that these variables change during the run: loss, F KLDiv, and NLL. Which one is the best indicator?

kmdalton commented 1 year ago

Hi @gyuhyeokcho,

To clarify, there is no clear consensus on the optimal number of optimization steps at this time. Typically, we have just run training to convergence with ~30k iterations. However, I can currently give you no guarantee that this strategy is best. I would recommend cross-validation as a principled way to select the number of optimization steps.

You can designate a fraction of your data to be held out during optimization with the --test-fraction flag. For instance, --test-fraction=0.05 will reserve 5% of the reflections for the test set. This procedure is analogous to the "free" set used to compute Rfree in structure refinement.
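For concreteness, an invocation with a held-out test fraction might look like the line below. This is only a sketch: the metadata keys, input file, and output prefix are placeholders, and option names can differ between versions, so please check careless mono --help for your installation before copying it.

careless mono --iterations=30000 --test-fraction=0.05 "dHKL,xobs,yobs,image_id" unmerged.mtz merged/hewl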

If specified, test reflections will be marked in the *_predictions_#.mtz files which careless writes at the end of training. You can run careless.ccpred on the predictions files to compare the correlation between predictions and observations for the test and training fractions. A principled way to decide the number of optimization steps is to choose the setting with the smallest gap between the train and test correlation coefficients. The reasoning is that a model which generalizes well to previously unseen data is less overfit to any particular subset of the data. This is essentially what we did in Figure 3 of the paper when choosing the degrees of freedom of the loss function.
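To make the "smallest gap" rule concrete, here is a minimal pandas sketch of the selection step. The numbers and column names (steps, cc_train, cc_test) are placeholders rather than careless output; in practice you would fill in the table from the careless.ccpred results of each run you trained.

# Minimal sketch of the selection rule described above. The correlation
# coefficients here are hypothetical placeholders, not careless output.
import pandas as pd

# One row per training run: number of optimization steps and the overall
# CC between predictions and observations for the train and test fractions.
runs = pd.DataFrame({
    "steps":    [10_000, 20_000, 30_000, 50_000],
    "cc_train": [0.95, 0.96, 0.97, 0.98],
    "cc_test":  [0.93, 0.94, 0.94, 0.92],
})

# Pick the run with the smallest train/test gap, i.e. the least overfit model.
runs["gap"] = (runs["cc_train"] - runs["cc_test"]).abs()
best = runs.loc[runs["gap"].idxmin()]
print(f"Least overfit run: {best.steps:.0f} steps (gap = {best.gap:.3f})")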

Additionally, I will note that careless records the values of the F KLDiv, NLL, and Loss in a *_history.csv file. This plain text file may be of interest to you as well. It can be useful to plot these values as a function of the optimization step using seaborn. For example:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Each row of the history file holds the loss terms recorded at one optimization step.
df = pd.read_csv("hewl_history.csv")

# Melt to long format so each recorded quantity becomes its own line in the plot.
sns.lineplot(df.melt('step'), x='step', y='value', hue='variable')
plt.semilogy()
plt.show()
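If you would like a numerical check in addition to the plot, one rough heuristic is to compare the mean loss over the last block of steps with the block before it. This is just a sketch under the assumption that the history file has a column named loss for the total loss; check the header of your *_history.csv and adjust the column name if needed.

# Rough convergence check on the training history. The column name "loss"
# is an assumption about the history file's header; adjust as needed.
import pandas as pd

df = pd.read_csv("hewl_history.csv")
window = 1_000  # number of trailing steps to average over

recent = df["loss"].iloc[-window:].mean()
previous = df["loss"].iloc[-2 * window:-window].mean()

# If the loss has essentially stopped improving between the two windows,
# additional steps are unlikely to help.
print(f"mean loss over last {window} steps:      {recent:.4f}")
print(f"mean loss over preceding {window} steps: {previous:.4f}")
print(f"relative change: {(previous - recent) / abs(previous):.2%}")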

As I mentioned, we have not yet done an exhaustive study of the number of optimization steps. Instead, we have usually optimized until convergence and varied other parameters of the model about which we have more intuition. I make no claim that this is the best strategy. There are a great many ways to train the model, and we're very interested in user feedback so that we can begin assembling a set of best practices. We'd be very happy to hear about any insights you uncover through your experiments. Feel free to contact me or @DHekstra in whatever manner you feel comfortable, either privately or in a public GitHub issue.

Thanks for your excellent question!

Kevin

gyuhyeokcho commented 1 year ago

Thank you for your answer!