Closed dabianzhixing closed 1 year ago
Have you tried to load the weights to cpu first?
net = torch.load(pthfile, map_location="cpu")
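For anyone landing here later, a minimal sketch of what `map_location="cpu"` does. The `nn.Linear` model and the file name `demo_weights.pth` are stand-ins for illustration only, not the GearNet checkpoint from this issue:

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in model; in this thread it would be the GearNet checkpoint instead.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_weights.pth")  # hypothetical file name
    torch.save(model.state_dict(), path)

    # map_location="cpu" remaps every stored tensor onto the CPU while loading,
    # so the load itself does not need a working CUDA setup.
    state_dict = torch.load(path, map_location="cpu")

# Copy the loaded weights into a fresh model of the same shape.
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)

loaded_devices = {tensor.device.type for tensor in state_dict.values()}
print(loaded_devices)
```

Once the weights are on the CPU, the model can still be moved to a GPU later with `.to(device)` if the CUDA setup allows it.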
Thank you very much. The weight could be loaded on the CPU. However, I still can't get the embedding of the protein. My code is as follows:
import torch
from torchdrug import data, layers, models
from torchdrug.layers import geometry

device = torch.device("cpu")
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5), geometry.KNNEdge(k=10, min_distance=5), geometry.SequentialEdge(max_distance=2)], edge_feature="gearnet")
protein = data.Protein.from_pdb('4k0e_C.pdb', atom_feature="position", bond_feature="length", residue_feature="symbol")
_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, edge_input_dim=59, num_angle_bin=8, batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")
gearnet_edge.load_state_dict(torch.load('mc_gearnet_edge.pth', map_location=device), strict=True)
gearnet_edge.eval()
print(protein_.node_feature.float().shape)
print(protein_)
with torch.no_grad():
    output = gearnet_edge(protein_, protein_.node_feature.float(), all_loss=None, metric=None)
print(output)
I get an error like this:
Traceback (most recent call last):
File "GearNet/MC_gearnet_edge.py", line 45, in
It seems that "GeometricRelationalGraphConv" wants to convert the feature into a specific shape, but the conversion fails. What should I do?
Hi, I guess it's because you forgot to switch the protein to the residue view, so that residue_feature is passed as node_feature instead of atom_feature. In the tutorial, we define the transform at the beginning and use it for processing batches. In your case, where you only want to process one protein, you can simply set protein_.view = "residue".
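A rough sketch of how this fix could be guarded before calling the model. `ensure_residue_view` is a hypothetical helper, not part of TorchDrug; it only assumes the graph object exposes the `view` and `node_feature` attributes used in this thread:

```python
def ensure_residue_view(graph, expected_dim):
    """Switch `graph` to residue view and check the node feature width.

    Hypothetical helper, not a TorchDrug API. It relies only on the
    `view` and `node_feature` attributes mentioned in this thread.
    """
    # In residue view, node_feature refers to residue_feature, not atom_feature.
    graph.view = "residue"

    feature_dim = graph.node_feature.shape[-1]
    if feature_dim != expected_dim:
        raise ValueError(
            f"node_feature has width {feature_dim}, "
            f"but the model expects input_dim={expected_dim}"
        )
    return graph


# Usage with the code above (21 is the input_dim passed to models.GearNet):
# protein_ = ensure_residue_view(protein_, expected_dim=21)
```

Failing early with a clear message is easier to debug than a reshape error raised inside the convolution layer.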
Yes, you are right, now I can get the embedding of the proteins. Thank you very much!
When I try to load the weights like this,
net = torch.load(pthfile)
I got this error:
RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain.
I don't know what happened. After searching on Google, it seems that I need to update my CUDA driver. Are there any other options?
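This error usually points to a driver that is older than the CUDA toolkit the code was compiled with, so it may help to check which CUDA version the installed PyTorch build targets. A small diagnostic sketch (`cuda_report` is a name made up here), whose output can be compared with the driver version shown by `nvidia-smi`:

```python
import torch


def cuda_report():
    """Collect version facts relevant to a driver/toolkit mismatch."""
    return {
        "torch": torch.__version__,
        # CUDA toolkit version this PyTorch build was compiled against
        # (None for CPU-only builds).
        "built_with_cuda": torch.version.cuda,
        # Whether PyTorch can currently use a CUDA device.
        "cuda_available": torch.cuda.is_available(),
    }


report = cuda_report()
for key, value in report.items():
    print(f"{key}: {value}")
```

If the build targets a newer CUDA version than the driver supports, the options are roughly to update the driver or to install a PyTorch build made for an older CUDA version.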