a-r-j / ProteinWorkshop

Benchmarking framework for protein representation learning. Includes a large number of pre-training and downstream task datasets, models and training/task utilities. (ICLR 2024)
https://proteins.sh/
MIT License

Issues with feature-computations #56

Closed · martinaegidius closed this issue 10 months ago

martinaegidius commented 11 months ago

Hi, I want to train a protein model for node prediction for a school project, and ProteinWorkshop seems very promising!

I have a directory of downloaded .pdb files from AlphaFold predictions which I want to use for training in my own training loop, as this is part of the project specification.

I have been wanting to use GearNet, SchNet, or DimeNet for the task, but I am having trouble figuring out at which point the features for the proteins are computed, due to the high level of abstraction in the forward pass of the different models. I am having a hard time keeping track of the flow of configs and data processing through all the sub-modules. Can you clarify at which point graph features are computed?

Ideally, I would like to create graphs and features using Graphein, create protein batches, and train the models on these. I have been trying to reproduce the approach in the SchNet main call, but am additionally facing issues with tensors being of type float instead of long, mismatched shapes, etc. (minimal exemplifying showcase here).

Is it possible to do this using Graphein, or is it necessary to use the ProteinFeaturiser? If so, can I somehow apply the featuriser to .pdb files, a Graphein ProteinBatch or graph object, or anything along those lines?

Thank you for your help!

a-r-j commented 11 months ago

Hi @martinaegidius! Happy to help :)

I have been wanting to use GearNet, SchNet, or DimeNet for the task, but I am having trouble figuring out at which point the features for the proteins are computed, due to the high level of abstraction in the forward pass of the different models.

If you are using the BenchMarkModel object, node- and edge-level features are computed immediately after the batch is transferred to the device, i.e. before the forward pass. If you are using the individual encoder objects, you will have to do this manually (as per your notebook).
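Roughly, the pattern looks like this (a minimal sketch using PyTorch Lightning's on_after_batch_transfer hook; the exact hook and the featuriser attribute name are assumptions for illustration, not necessarily what BenchMarkModel does internally):

import lightning.pytorch as pl

class SketchModel(pl.LightningModule):
    def __init__(self, featuriser):
        super().__init__()
        self.featuriser = featuriser  # hypothetical attribute name

    def on_after_batch_transfer(self, batch, dataloader_idx: int):
        # Runs after the batch has been moved to the device,
        # before it reaches forward().
        return self.featuriser(batch)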

Can you clarify at which point graph features are computed?

We do not make use of graph-level featurisation at this point in time.

I have been trying to reproduce the approach in the SchNet main call, but am additionally facing issues with tensors being of type float instead of long, mismatched shapes, etc.

Let me check this out and try to reproduce it on my end.

Is it possible to do this using Graphein, or is it necessary to use the ProteinFeaturiser? If so, can I somehow apply the featuriser to .pdb files, a Graphein ProteinBatch or graph object, or anything along those lines?

In theory you can use whatever you like, as long as you follow the naming conventions. Encoders have a required_batch_attributes property which specifies what needs to be included in the batch. The shapes should follow the conventions that we use.
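For example, you can inspect the property directly (the attribute set shown in the comment is illustrative; check the printed value for the encoder you actually use):

from proteinworkshop.models.graph_encoders.schnet import SchNetModel

model = SchNetModel()
# Each encoder declares the batch attributes it expects.
print(model.required_batch_attributes)
# For SchNet this is along the lines of {"x", "pos", "edge_index", "batch"}.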

If so, can I somehow apply the featuriser to .pdb files, a Graphein ProteinBatch or graph object, or anything along those lines?

Yes, the featuriser is designed to be used on Graphein ProteinBatch objects. You can check this notebook for more details. For using your own data, graphein.protein.tensor.io.protein_to_pyg will be useful to you.
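A minimal sketch of that route, assuming a local AlphaFold PDB file (the filename is hypothetical, only the commonly used keyword arguments are shown, and the batching call assumes ProteinBatch inherits PyG's from_data_list):

from graphein.protein.tensor.io import protein_to_pyg
from graphein.protein.tensor.data import ProteinBatch

# Parse a single structure into a PyG-compatible data object.
data = protein_to_pyg(path="AF-P12345-F1-model_v4.pdb", chain_selection="all")

# Collate one or more structures into a ProteinBatch.
batch = ProteinBatch.from_data_list([data])

The resulting batch can then be passed through the ProteinFeaturiser before calling an encoder.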

a-r-j commented 11 months ago

Looking at the dtype issue in the block you mentioned:

from proteinworkshop.models.graph_encoders.schnet import SchNetModel

model = SchNetModel()  # instantiated directly, without a config
print(model)

# `batch` here is the unfeaturised example batch from the notebook.
out = model.forward(batch)  # fails: a tensor is float but should be long
print(out)

We need to look at batch:

ProteinBatch(fill_value=[1], atom_list=[1], residue_type=[574], batch=[574], ptr=[2], residues=[1], id=[1], edge_index=[2, 4592], x=[574], residue_id=[1], chains=[574], pos=[574, 3], coords=[574, 37, 3])

Here we can see that the batch has not been featurised: batch.x is the default placeholder of integers encoding the residue type. If you instead use the batch featurised by the ProteinFeaturiser, example_batch_f, it should work as expected.

from proteinworkshop.features.factory import ProteinFeaturiser
from proteinworkshop.datasets.utils import create_example_batch
from proteinworkshop.models.graph_encoders.schnet import SchNetModel

# ca_featuriser was undefined above; this is the C-alpha config from the README.
ca_featuriser = ProteinFeaturiser(
    representation="CA",
    scalar_node_features=["amino_acid_one_hot"],
    vector_node_features=[],
    edge_types=["knn_16"],
    scalar_edge_features=["edge_distance"],
    vector_edge_features=[],
)

example_batch = create_example_batch()
example_batch_f = ca_featuriser(example_batch)

model = SchNetModel()
out = model.forward(example_batch_f)
print(out)

As for the shape issue, it looks like it was probably due to initialisation of the lazy layers (they're generally well-behaved, but can sometimes throw an error in my experience). In any case, I managed to run your notebook in Colab from scratch. Perhaps try restarting your kernel or using a fresh runtime.
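For context, a minimal illustration of the lazy-layer behaviour referred to above (plain PyTorch; the connection to SchNetModel's internals is an assumption):

import torch
import torch.nn as nn

layer = nn.LazyLinear(out_features=8)  # in_features is inferred lazily
x = torch.randn(4, 16)
out = layer(x)  # the first call materialises the weight as (8, 16)
print(layer.weight.shape)  # torch.Size([8, 16])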

martinaegidius commented 10 months ago

Hi

Thank you so much! This cleared things up, and now everything seems to work.