molML / MoleculeACE

A tool for evaluating the predictive performance of machine learning models on activity cliff compounds

Unable to run README example #3

Closed: PatWalters closed this issue 1 year ago

PatWalters commented 1 year ago

I've been unable to run the example. It doesn't seem possible to directly reproduce the environment you used, and I'm getting an exception when I try to run your code in an environment I created with:

conda create -n moleculeACE python=3.8
conda activate moleculeACE
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
pip install tensorflow
conda install pyg -c pyg
pip install transformers
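
A quick sanity check of what those commands actually resolved (assuming they all installed cleanly; these are the standard import names for each package):

import torch
import torch_geometric
import tensorflow as tf
import transformers

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("tensorflow:", tf.__version__)
print("transformers:", transformers.__version__)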

When I try to run the README example, I get an exception on model.train(data.x_train, data.y_train):

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 model.train(data.x_train, data.y_train)

File ~/software/MoleculeACE/MoleculeACE/models/utils.py:82, in GNN.train(self, x_train, y_train, x_val, y_val, early_stopping_patience, epochs, print_every_n)
     78     break
     80 # As long as the model is still improving, continue training
     81 else:
---> 82     loss = self._one_epoch(train_loader)
     83     self.train_losses.append(loss)
     85     val_loss = 0

File ~/software/MoleculeACE/MoleculeACE/models/utils.py:119, in GNN._one_epoch(self, train_loader)
    116 self.optimizer.zero_grad()
    118 # Forward pass
--> 119 y_hat = self.model(batch.x.float(), batch.edge_index, batch.edge_attr.float(), batch.batch)
    121 # Calculating the loss and gradients
    122 loss = self.loss_fn(squeeze_if_needed(y_hat), squeeze_if_needed(batch.y))

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/software/MoleculeACE/MoleculeACE/models/mpnn.py:104, in MPNNmodel.forward(self, x, edge_index, edge_attr, batch)
    101     node_feats = node_feats.squeeze(0)
    103 # perform global pooling using a multiset transformer to get graph-wise hidden embeddings
--> 104 out = self.transformer(node_feats, batch, edge_index)
    106 # Apply a fully connected layer.
    107 for k in range(len(self.fc)):

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/aggr/base.py:131, in Aggregation.__call__(self, x, index, ptr, dim_size, dim, **kwargs)
    126         if index.numel() > 0 and dim_size <= int(index.max()):
    127             raise ValueError(f"Encountered invalid 'dim_size' (got "
    128                              f"'{dim_size}' but expected "
    129                              f">= '{int(index.max()) + 1}')")
--> 131 return super().__call__(x, index, ptr, dim_size, dim, **kwargs)

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/aggr/gmt.py:245, in GraphMultisetTransformer.forward(self, x, index, ptr, dim_size, dim, edge_index)
    243 for i, (name, pool) in enumerate(zip(self.pool_sequences, self.pools)):
    244     graph = (x, edge_index, index) if name == 'GMPool_G' else None
--> 245     batch_x = pool(batch_x, graph, mask)
    246     mask = None
    248 return self.lin2(batch_x.squeeze(1))

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/aggr/gmt.py:133, in PMA.forward(self, x, graph, mask)
    127 def forward(
    128     self,
    129     x: Tensor,
    130     graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
    131     mask: Optional[Tensor] = None,
    132 ) -> Tensor:
--> 133     return self.mab(self.S.repeat(x.size(0), 1, 1), x, graph, mask)

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/aggr/gmt.py:59, in MAB.forward(self, Q, K, graph, mask)
     57 if graph is not None:
     58     x, edge_index, batch = graph
---> 59     K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
     60     K, _ = to_dense_batch(K, batch)
     61     V, _ = to_dense_batch(V, batch)

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/conv/gcn_conv.py:198, in GCNConv.forward(self, x, edge_index, edge_weight)
    195 x = self.lin(x)
    197 # propagate_type: (x: Tensor, edge_weight: OptTensor)
--> 198 out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
    199                      size=None)
    201 if self.bias is not None:
    202     out = out + self.bias

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:392, in MessagePassing.propagate(self, edge_index, size, **kwargs)
    389     if res is not None:
    390         edge_index, size, kwargs = res
--> 392 size = self.__check_input__(edge_index, size)
    394 # Run "fused" message and aggregation (if applicable).
    395 if is_sparse(edge_index) and self.fuse and not self.explain:

File ~/anaconda3/envs/moleculeACE/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:216, in MessagePassing.__check_input__(self, edge_index, size)
    213         the_size[1] = size[1]
    214     return the_size
--> 216 raise ValueError(
    217     ('`MessagePassing.propagate` only supports integer tensors of '
    218      'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
    219      '`torch.sparse.Tensor` for argument `edge_index`.'))

ValueError: `MessagePassing.propagate` only supports integer tensors of shape `[2, num_messages]`, `torch_sparse.SparseTensor` or `torch.sparse.Tensor` for argument `edge_index`.
dangraysf commented 1 year ago

Hey Pat -- I had the same issue. For mpnn.py, gat.py, and gcn.py, try this in the forward pass:

out = self.transformer(x=node_feats, index=batch, edge_index=edge_index)

This is for mpnn.py; the equivalent calls in gat.py and gcn.py are slightly different.

What is going on here is that the arguments have to be named explicitly instead of passed positionally.
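
A minimal, self-contained illustration of the failure mode. The toy function below is hypothetical, but its parameter order mirrors the Aggregation.__call__ signature visible in the traceback (x, index, ptr, dim_size, dim, ..., edge_index):

# Toy stand-in for the PyG 2.2 aggregation call; parameter order copied
# from the Aggregation.__call__ frame in the traceback above.
def transformer(x, index, ptr=None, dim_size=None, dim=-2, edge_index=None):
    return {"ptr": ptr, "edge_index": edge_index}

node_feats, batch, edge_index = "feats", "batch", "edges"

# Old-style positional call: edge_index silently lands in `ptr`, and the
# real edge_index parameter stays None -- hence the ValueError downstream.
print(transformer(node_feats, batch, edge_index))
# {'ptr': 'edges', 'edge_index': None}

# Keyword call routes each argument to the right parameter.
print(transformer(x=node_feats, index=batch, edge_index=edge_index))
# {'ptr': None, 'edge_index': 'edges'}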

At least for my venv --

print(torch.__version__)            # 1.13.1+cu117
print(torch_geometric.__version__)  # 2.2.0

GraphMultisetTransformer's old interface may be deprecated: in recent torch_geometric releases it lives in the nn.aggr module, and (as the Aggregation.__call__ frame in the traceback shows) its call signature inserts ptr, dim_size, and dim between index and edge_index, so explicit naming of the args makes a difference.
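
If one code path needs to work across both the old and the refactored PyG, a guarded call along these lines is an option. This is a sketch only: the 2.1 version boundary for the aggr refactor and the old positional signature are assumptions worth checking against the installed release.

# Sketch for MPNNmodel.forward in MoleculeACE/models/mpnn.py; the 2.1
# boundary for PyG's aggr refactor is an assumption -- verify locally.
from packaging import version
import torch_geometric

NEW_AGGR_GMT = version.parse(torch_geometric.__version__) >= version.parse("2.1")

# ... inside forward(), at the call the traceback flags as mpnn.py:104:
if NEW_AGGR_GMT:
    # aggr-based GMT signature is (x, index, ptr, dim_size, dim, edge_index),
    # so the arguments must be passed by keyword.
    out = self.transformer(x=node_feats, index=batch, edge_index=edge_index)
else:
    # the older glob-based GMT accepted the positional (x, batch, edge_index) call.
    out = self.transformer(node_feats, batch, edge_index)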

More can be found here: https://github.com/pyg-team/pytorch_geometric/issues/3443 -- cheers!

PatWalters commented 1 year ago

Thanks! I'll give this a try.