melloddy / SparseChem

Fast and accurate machine learning models for biochemical applications.
MIT License

Predicting last_hidden #4

Closed wheyndri closed 2 years ago

wheyndri commented 2 years ago

In predict.py there is an option to output the last hidden layer instead of the predictive probabilities:

parser.add_argument("--last_hidden", help="If set to 1 returns last hidden layer instead of Yhat", type=int, default=0)

However, the documentation does not make clear which of two layers this returns: the last layer of the shared trunk, or the last layer of the whole network. In the latter case, when a hidden layer is used in the head, the output would be the layer from the private head.

Probably the last layer of the shared trunk would be most useful since it can be seen as a common representation of compounds.

In addition, @NJSturm reports that predicting the last hidden layer with a hybrid model requires specifying which arm to take it from (classification and/or regression), so we might prefer to use the last hidden layer of the trunk.
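To make the distinction concrete, here is a minimal sketch of the two candidate outputs. It is a toy network in plain Python, not SparseChem's actual model code: the functions `dense` and `forward`, the `return_layer` argument, and the layer layout (one trunk layer, one hidden head layer, one output layer) are all assumptions made for illustration.

```python
import math

def dense(x, weights, bias):
    """One fully connected layer followed by ReLU.

    `weights` holds one list of input weights per output unit.
    """
    return [max(0.0, sum(xi * wi for xi, wi in zip(x, unit)) + b)
            for unit, b in zip(weights, bias)]

def forward(x, trunk, head, output, return_layer="yhat"):
    """Toy forward pass (hypothetical, for illustration only).

    return_layer="trunk" -> last layer of the shared trunk
    return_layer="head"  -> hidden layer inside the private head
    return_layer="yhat"  -> predicted probabilities
    """
    trunk_out = dense(x, *trunk)
    if return_layer == "trunk":
        return trunk_out
    head_out = dense(trunk_out, *head)
    if return_layer == "head":
        return head_out
    out_weights, out_bias = output
    logits = [sum(h * w for h, w in zip(head_out, unit)) + b
              for unit, b in zip(out_weights, out_bias)]
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

# Tiny example: 2 input features, 3 trunk units, 2 head units, 1 task.
x = [1.0, 2.0]
trunk = ([[1, 0], [0, 1], [1, 1]], [0, 0, 0])
head = ([[1, 1, 1], [-1, -1, -1]], [0, 0])
output = ([[0, 0]], [0])

trunk_embedding = forward(x, trunk, head, output, return_layer="trunk")
head_embedding = forward(x, trunk, head, output, return_layer="head")
```

In this sketch the trunk embedding is shared by every task, whereas the head embedding depends on which head (classification or regression) it is taken from, which is why a hybrid model would need an extra choice of arm.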

NJSturm commented 2 years ago

@molden : I support the request made by @wheyndri of having the possibility to get the trunk embeddings in addition to having the embeddings at the last hidden (which could be from the private head).

I also noticed that the error on L297 of models.py is raised when trying to predict the last hidden layer with predict_hidden() on a simple classification or simple regression model.
It seems to be assumed that a hybrid model with catalog fusion disabled is in use, and that last hidden layer prediction is consequently not supported (see L288).

Would it be possible to support last hidden and trunk embeddings if cat_fusion = 0?
For the time being, checking whether the model is a hybrid model could be done by looking at the classification/regression output sizes...
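A rough sketch of the check suggested above, assuming the model configuration exposes one output size for classification and one for regression. The function `model_kind` and its parameter names are hypothetical, not SparseChem's API; they only illustrate the idea of inferring the model type from the two sizes.

```python
def model_kind(class_output_size, regr_output_size):
    """Infer the model type from the two output sizes.

    Hypothetical helper: a size of None or 0 is taken to mean
    that the corresponding arm is absent.
    """
    has_class = bool(class_output_size)
    has_regr = bool(regr_output_size)
    if has_class and has_regr:
        return "hybrid"
    if has_class:
        return "classification"
    if has_regr:
        return "regression"
    raise ValueError("model has neither classification nor regression outputs")

# Only a hybrid model would need to choose an arm for a head-level embedding.
kind = model_kind(class_output_size=10, regr_output_size=5)
```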

molden commented 2 years ago

Would it be possible to quickly test the following branch? https://github.com/melloddy/SparseChem/tree/4-predicting-last_hidden
The option is now called --trunk_embeddings instead of --last_hidden.

wheyndri commented 2 years ago

@molden Thanks for looking into this. I tested a regression and a hybrid MELLODDY model (both with cat_fusion = 0): no errors to report. Dimensions and values look reasonable.

NJSturm commented 2 years ago

@molden : thanks, this ran through on my side as well for a classification and a hybrid model (catalog fusion = 0 on both).

@Fabien-GELUS : this is worth looking into for a potential integration in MELLODDY-Predictor (it outputs the embeddings of compounds in the latent representation of the last layer of the shared trunk). It relates to a feature request I made a while back.