a-r-j / graphein

Protein Graph Library
https://graphein.ai/
MIT License
1.01k stars 126 forks

esm_residue_embedding not working #274

Closed avivko closed 1 year ago

avivko commented 1 year ago

Describe the bug 1) gp.esm_residue_embedding raises an error (see error stack 1 below) during graph construction. 2) Wrapping the function in a partial, partial(gp.esm_residue_embedding, model_name="esm2_t33_650M_UR50D"), also raises an error (see error stack 2).

To Reproduce Run this Colab: https://colab.research.google.com/drive/1M2N9ZFS7WGdnyNBSo--n8l9hBwEbO0qn?usp=sharing

Expected behavior gp.esm_residue_embedding should work and should also work with a partial specifying the model.

Screenshots Error stack 1:

Using cache found in ./torch_home/hub/facebookresearch_esm_main
[03/09/23 02:13:32] DEBUG    Deprotonating protein. This removes H atoms from   [graphs.py](file:///usr/local/lib/python3.9/dist-packages/graphein/protein/graphs.py):176
                             the pdb_df dataframe                                            
                    DEBUG    Detected 1365 total nodes                          [graphs.py](file:///usr/local/lib/python3.9/dist-packages/graphein/protein/graphs.py):383
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
[<ipython-input-9-764532a0e598>](https://localhost:8080/#) in <module>
     17                              ])
     18 
---> 19 g = construct_graph(config=graphein_config, pdb_code="6rew")
     20 
     21 p = plotly_protein_structure_graph(

2 frames
[/usr/local/lib/python3.9/dist-packages/graphein/protein/graphs.py](https://localhost:8080/#) in construct_graph(config, name, pdb_path, uniprot_id, pdb_code, df, chain_selection, model_index, df_processing_funcs, edge_construction_funcs, edge_annotation_funcs, node_annotation_funcs, graph_annotation_funcs, verbose)
    788         # Annotate additional node metadata
    789         if config.node_metadata_functions is not None:
--> 790             g = annotate_node_metadata(g, config.node_metadata_functions)
    791 
    792         if verbose:

[/usr/local/lib/python3.9/dist-packages/graphein/utils/utils.py](https://localhost:8080/#) in annotate_node_metadata(G, funcs)
    110     for func in funcs:
    111         for n, d in G.nodes(data=True):
--> 112             func(n, d)
    113     return G
    114 

[/usr/local/lib/python3.9/dist-packages/graphein/protein/features/sequence/embeddings.py](https://localhost:8080/#) in esm_residue_embedding(G, model_name, output_layer)
    196     """
    197 
--> 198     for chain in G.graph["chain_ids"]:
    199         embedding = compute_esm_embedding(
    200             G.graph[f"sequence_{chain}"],

AttributeError: 'str' object has no attribute 'graph'

Error stack 2:

Using cache found in ./torch_home/hub/facebookresearch_esm_main
                    DEBUG    Detected 1365 total nodes                          [graphs.py](file:///usr/local/lib/python3.9/dist-packages/graphein/protein/graphs.py):383
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-10-d3141a5de7c5>](https://localhost:8080/#) in <module>
     17                              ])
     18 
---> 19 g = construct_graph(config=graphein_config, pdb_code="6rew")
     20 
     21 p = plotly_protein_structure_graph(

1 frames
[/usr/local/lib/python3.9/dist-packages/graphein/protein/graphs.py](https://localhost:8080/#) in construct_graph(config, name, pdb_path, uniprot_id, pdb_code, df, chain_selection, model_index, df_processing_funcs, edge_construction_funcs, edge_annotation_funcs, node_annotation_funcs, graph_annotation_funcs, verbose)
    788         # Annotate additional node metadata
    789         if config.node_metadata_functions is not None:
--> 790             g = annotate_node_metadata(g, config.node_metadata_functions)
    791 
    792         if verbose:

[/usr/local/lib/python3.9/dist-packages/graphein/utils/utils.py](https://localhost:8080/#) in annotate_node_metadata(G, funcs)
    110     for func in funcs:
    111         for n, d in G.nodes(data=True):
--> 112             func(n, d)
    113     return G
    114 

TypeError: esm_residue_embedding() got multiple values for argument 'model_name'
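Both tracebacks look consistent with one cause: annotate_node_metadata calls each function as func(n, d), where n is a node id string and d is that node's data dict, while esm_residue_embedding takes the whole graph as its first argument. A minimal sketch of that mismatch follows. It assumes nothing beyond the parameter order shown in the traceback; fake_residue_embedding, call_like_node_annotator, the default values, and the sample node are hypothetical stand-ins, not Graphein code.

```python
from functools import partial


def fake_residue_embedding(G, model_name="some-default-model", output_layer=33):
    """Hypothetical stand-in using the parameter order seen in the traceback."""
    # Like the real function, this expects a graph object with a .graph dict.
    return G.graph["chain_ids"]


def call_like_node_annotator(func):
    """Mimic the `func(n, d)` call for a single node and report what happens."""
    n, d = "A:MET:1", {"residue_name": "MET"}  # node id is a plain string
    try:
        func(n, d)
    except (AttributeError, TypeError) as err:
        return f"{type(err).__name__}: {err}"
    return "ok"


# Case 1: the string node id lands in G, so G.graph fails.
plain_error = call_like_node_annotator(fake_residue_embedding)

# Case 2: d is passed positionally into model_name, which the partial
# has already bound by keyword, so the call itself fails.
partial_error = call_like_node_annotator(
    partial(fake_residue_embedding, model_name="esm2_t33_650M_UR50D")
)

print(plain_error)
print(partial_error)
```

In the first case the failure happens inside the function body; in the second it happens at call time, before the body runs, which would explain why error stack 2 has one frame fewer.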


a-r-j commented 1 year ago

Hey @avivko

Could you try with:

graphein_config = gp.ProteinGraphConfig(
    graph_metadata_functions=[
        # Use one of the two: the default model ...
        gp.esm_residue_embedding,
        # ... or a specific model via functools.partial:
        # partial(gp.esm_residue_embedding, model_name=model),
    ]
)

I tried it in your colab and it seems to work (thanks for including that!).

It’s unintuitive (and ugly API-wise) but the per-residue node embeddings need access to the whole sequence in order to compute them, so this function doesn’t quack like typical node metadata functions which depend exclusively on information contained within nodes.

The underlying function should still distribute the embeddings to the nodes.
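To illustrate the pattern described above, here is a rough sketch of a graph-level function that reads the full chain sequence from the graph attributes and then writes one vector onto each node. It is not Graphein's implementation: TinyGraph, toy_embed, toy_residue_embedding, and the "toy_embedding" key are hypothetical, and the sketch assumes node ids of the form chain:residue:position and nodes stored in sequence order. Only the G.graph["chain_ids"] and G.graph[f"sequence_{chain}"] lookups are taken from the traceback.

```python
class TinyGraph:
    """Hypothetical stand-in for the networkx graph that Graphein builds."""

    def __init__(self, graph_attrs, node_ids):
        self.graph = graph_attrs                    # graph-level attributes
        self._nodes = {n: {} for n in node_ids}     # node id -> data dict

    def nodes(self, data=False):
        return list(self._nodes.items()) if data else list(self._nodes)


def toy_embed(sequence):
    """Placeholder for a language model: one vector per residue.

    Each vector depends on the whole sequence (here, via its length),
    which is why a single node's data is not enough to compute it.
    """
    return [[float(ord(aa)), float(len(sequence))] for aa in sequence]


def toy_residue_embedding(G):
    """Graph-level function: takes the graph, annotates every node."""
    for chain in G.graph["chain_ids"]:
        embedding = toy_embed(G.graph[f"sequence_{chain}"])
        # Assumes node ids start with the chain id and are in sequence order.
        chain_nodes = [d for n, d in G.nodes(data=True) if n.split(":")[0] == chain]
        for d, vector in zip(chain_nodes, embedding):
            d["toy_embedding"] = vector
    return G


G = TinyGraph({"chain_ids": ["A"], "sequence_A": "MK"}, ["A:MET:1", "A:LYS:2"])
toy_residue_embedding(G)
annotated = dict(G.nodes(data=True))
print(annotated)
```

Because the function needs the graph rather than a single (n, d) pair, it is registered under graph_metadata_functions, even though its output ends up on the nodes.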

a-r-j commented 1 year ago

How did you get on @avivko? Can I close the issue?

I've tried to catch the error with a more informative message. This is now resolved in 1.6.0 (pip install graphein==1.6.0).

avivko commented 1 year ago

Sorry for the delay in my response. Your suggestion worked, thanks!