kaist-amsg / LocalRetro

Retrosynthesis prediction for organic molecules with LocalRetro

Does the switch to GELU break backward compatibility? #19

Closed kmaziarz closed 11 months ago

kmaziarz commented 1 year ago

A recent commit changed the activation function from ReLU to GELU. Wouldn't this potentially break previously trained checkpoints (i.e. if training was done with ReLU, while inference was done with the newest codebase that already has GELU)?

shuan4638 commented 1 year ago

The current model in the models/ directory was trained with the GELU activation function. Of course, if one trained the model with the previous code, the activation should be changed back to ReLU.

kmaziarz commented 1 year ago

Yes, I meant existing old checkpoints, independently trained by users of the codebase. I think it might be a bit confusing if somebody has trained their own LocalRetro model, and suddenly the model behaviour changes slightly after pulling the newest code. Ideally, GeLU would be the default for new models, while old models would continue to use ReLU they were trained with.

shuan4638 commented 1 year ago

Thanks for the comments. Now I understand it is confusing for people who trained their own model before and then pulled the newest code. I was not aware of this issue; I should have made the change on a separate branch.

Is there any way to resolve this issue at this stage?

kmaziarz commented 1 year ago

Maybe the choice of activation function could be added as another argument to LocalRetro_model.__init__, set through the *.json config like the other architectural hyperparameters? Old configs would lack the field specifying the activation function, but then the code loading the config (e.g. get_configure) would make sure to interpret the missing value as ReLU.
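
Roughly what I have in mind, as a sketch only (the helper names below are illustrative, not the actual LocalRetro code):

```python
import torch.nn as nn

def get_activation(name: str) -> nn.Module:
    """Map the config string to the corresponding torch activation module."""
    return nn.GELU() if name.lower() == 'gelu' else nn.ReLU()

class ExampleLayer(nn.Module):
    """Illustrative layer only, not the real LocalRetro_model."""
    def __init__(self, in_feats: int, out_feats: int, activation: str = 'relu'):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.act = get_activation(activation)

    def forward(self, x):
        return self.act(self.linear(x))

def get_configure_sketch(config: dict) -> dict:
    # old configs predate the GELU commit and lack the 'activation' field;
    # interpret the missing value as 'relu' so old checkpoints keep their behaviour
    config.setdefault('activation', 'relu')
    return config
```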

kmaziarz commented 1 year ago

I've now taken a checkpoint trained with the old code and verified the following:

shuan4638 commented 1 year ago

> Maybe the choice of activation function could be added as another argument to LocalRetro_model.__init__, set through the *.json config like the other architectural hyperparameters? Old configs would lack the field specifying the activation function, but then the code loading the config (e.g. get_configure) would make sure to interpret the missing value as ReLU.

This is a good suggestion: it will raise an error if the activation flag does not exist in the config file, asking the user to add the activation function back.

I added this feature in my latest commit and it passed a simple test. Thanks!

kmaziarz commented 1 year ago

Thanks! The change looks good, although I notice that your earlier GELU commit changed the activation function not only in the model, but also in the FeedForward layer used within the Global_Reactivity_Attention module. So perhaps the activation determined in LocalRetro_model.__init__ should be also passed down into Global_Reactivity_Attention and from there into FeedForward.
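
Schematically, something like this, where the activation chosen at the top level is threaded down (the constructor signatures here are simplified stand-ins, not the real ones in model_utils.py):

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """Simplified stand-in for the FeedForward layer in model_utils.py."""
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        # the activation is supplied by the caller instead of being hard-coded
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), activation, nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)

class Global_Reactivity_Attention(nn.Module):
    """Simplified: only shows the activation being forwarded to FeedForward."""
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.feed_forward = FeedForward(d_model, d_ff, activation)
```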

shuan4638 commented 1 year ago

That's a great catch, I didn't notice that. I also updated model.py and model_utils.py so the activation flag in the config will be passed to both files. Thanks a lot.

kmaziarz commented 1 year ago

Looks good!

One more thought: as you said, currently using an old checkpoint will fail due to the activation key missing from the config. This may confuse some users, as they may not even know which activation function their model used during training (pretty much the only way to figure that out is to analyse the recent git commits or see this issue thread). Maybe it would make sense to replace exp_config['activation'] with exp_config.get('activation', 'relu'), so that old checkpoints work correctly without modification? New checkpoints would still use GELU by default, as the default config provided in the repo specifies that; likely the only situation where the activation key would be missing is when one trained a model using old code and then pulled in the newest changes.
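
For clarity, the difference is just the following (a toy example; the 'hidden_size' key is only a placeholder for the other config fields):

```python
# an old config that predates the GELU change (other fields omitted for brevity)
exp_config = {'hidden_size': 320}

# exp_config['activation'] would raise KeyError here; .get falls back to 'relu',
# so checkpoints trained with the old code load without editing their config
activation = exp_config.get('activation', 'relu')
print(activation)  # relu
```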

shuan4638 commented 1 year ago

Great solution. I didn't even know about this Python dictionary method. I have updated this as well; hope everything works fine.

kmaziarz commented 12 months ago

I'm taking a further look at what the recent code changes were, and I came across these changes to get_edit_site. Am I parsing this correctly that the previous code was essentially returning all pairs of nodes as the second return value (also those not connected by a bond), whereas the new code only includes the pairs connected by bonds?

shuan4638 commented 12 months ago

> the previous code was essentially returning all pairs of nodes

If you are talking about lines 41-49, that code only collects the connected bonds because I only iterate over the bonds connected to each atom, so the other list only contains neighboring atoms.

I changed this code to fit the order of the default DGL edges so I could modify the pair_atom_feats function in model_utils to resolve the DGL version issue (#15).
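
A minimal RDKit sketch of what I mean, illustrative only rather than the exact get_edit_site code:

```python
from rdkit import Chem

def connected_atom_pairs(smiles: str):
    """Collect (atom_idx, neighbor_idx) pairs only for atoms joined by a bond."""
    mol = Chem.MolFromSmiles(smiles)
    pairs = []
    for atom in mol.GetAtoms():
        # atom.GetBonds() yields only the bonds incident to this atom,
        # so pairs of atoms that are not bonded never enter the list
        for bond in atom.GetBonds():
            pairs.append((atom.GetIdx(), bond.GetOtherAtom(atom).GetIdx()))
    return pairs

print(connected_atom_pairs('CCO'))  # e.g. [(0, 1), (1, 0), (1, 2), (2, 1)]
```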

kmaziarz commented 12 months ago

Ah, sorry, I misread atom.GetBonds() in the previous code as mol.GetBonds(). Makes sense then.