pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

UserWarning: RNN module weights are not part of single contiguous chunk of memory #7999

Open erikhuck opened 1 year ago

erikhuck commented 1 year ago

🐛 Describe the bug

Using LSTM jumping knowledge (jk='lstm') with the GAT model results in a warning when evaluating the model (but not when training it):

/mlab/data/edhu227/.miniconda3/envs/acmpp/lib/python3.10/site-packages/torch/nn/modules/rnn.py:812: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995026/work/aten/src/ATen/native/cudnn/RNN.cpp:982.)
  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,

It appears that this can be fixed by calling flatten_parameters() on the internal LSTM model.
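
For reference, here is a minimal sketch of that workaround. It assumes the GAT wrapper exposes its jumping-knowledge module as .jk and the underlying LSTM as .lstm; these are PyG internals and the attribute names may differ between versions:

# Hedged sketch: compact the cuDNN weights of the LSTM used by the 'lstm'
# jumping-knowledge mode. The path model.gat.jk.lstm is an assumption about
# PyG internals.
def flatten_jk_lstm(model):
    jk = getattr(model.gat, 'jk', None)
    lstm = getattr(jk, 'lstm', None) if jk is not None else None
    if lstm is not None:
        # Re-packs the RNN weights into one contiguous chunk of memory so
        # cuDNN does not have to compact them on every call.
        lstm.flatten_parameters()

Calling this once after the model has been moved to the GPU (e.g. right after construction, or before the evaluation loop) should silence the warning.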

My model:

import torch
import torch_geometric.nn as tgnn
import torch_geometric.nn.resolver as tgres  # assumed aliases for the nn and resolver modules used below


class GraphAttentionNetwork(torch.nn.Module):
    def __init__(
            self, device, n_atom_features: int, n_bond_features: int | None, gat_hidden_channels: int, gat_n_layers,
            gat_out_channels: int, gat_act_func: str, gat_norm: str, aggregation: str, gat_dropout: float, n_heads: int,
            leaky_relu_slope: float, mlp_hidden_channels: int, mlp_n_layers: int, mlp_act_func: str, mlp_norm: str, mlp_dropout: float):
        super(GraphAttentionNetwork, self).__init__()
        self.device = device
        self.gat_act_func = tgres.activation_resolver(gat_act_func)
        self.gat = tgnn.GAT(
            in_channels=n_atom_features, hidden_channels=gat_hidden_channels*n_heads, num_layers=gat_n_layers, out_channels=gat_out_channels,
            dropout=gat_dropout, v2=True, act=self.gat_act_func, act_first=False, norm=gat_norm, jk='lstm', heads=n_heads,
            edge_dim=n_bond_features, negative_slope=leaky_relu_slope, add_self_loops=False)
        if aggregation in ('max', 'add', 'mean'):
            mlp_in_channels = gat_out_channels
        elif aggregation in ('max-add', 'max-mean', 'add-mean'):
            mlp_in_channels = gat_out_channels * 2
        elif aggregation == 'max-add-mean':
            mlp_in_channels = gat_out_channels * 3
        else:
            raise ValueError(
                f'Invalid value for aggregation: {aggregation}. Valid values are max, add, mean, max-add, max-mean, add-mean,'
                f' or max-add-mean')
        self.aggregation = aggregation
        self.gat_norm = tgres.normalization_resolver(gat_norm, in_channels=mlp_in_channels)
        self.mlp = tgnn.MLP(
            in_channels=mlp_in_channels, hidden_channels=mlp_hidden_channels, num_layers=mlp_n_layers, out_channels=1, dropout=mlp_dropout,
            act=mlp_act_func, act_first=False, norm=mlp_norm)
        self.to(device)

    def forward(self, data):
        atom_features, bond_indices, bond_features, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.gat(x=atom_features, edge_index=bond_indices, edge_attr=bond_features)
        if self.aggregation == 'max':
            x = tgnn.global_max_pool(x, batch)
        elif self.aggregation == 'add':
            x = tgnn.global_add_pool(x, batch)
        elif self.aggregation == 'mean':
            x = tgnn.global_mean_pool(x, batch)
        elif self.aggregation == 'max-add':
            x = torch.cat([tgnn.global_max_pool(x, batch), tgnn.global_add_pool(x, batch)], dim=1)
        elif self.aggregation == 'max-mean':
            x = torch.cat([tgnn.global_max_pool(x, batch), tgnn.global_mean_pool(x, batch)], dim=1)
        elif self.aggregation == 'add-mean':
            x = torch.cat([tgnn.global_add_pool(x, batch), tgnn.global_mean_pool(x, batch)], dim=1)
        elif self.aggregation == 'max-add-mean':
            x = torch.cat([tgnn.global_max_pool(x, batch), tgnn.global_add_pool(x, batch), tgnn.global_mean_pool(x, batch)], dim=1)
        else:
            raise ValueError(f'Invalid value for self.aggregation: {self.aggregation}.')
        x = self.gat_norm(x)
        x = self.gat_act_func(x)
        x = self.mlp(x)
        return torch.sigmoid(x).flatten()

Prediction code (run after training for a number of epochs):

def predicting(model, device, loader):
    model.eval()
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)

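With the hypothetical flatten_jk_lstm helper sketched above, the workaround would slot in right after model.eval():

model.eval()
# Assumption: compacting the jumping-knowledge LSTM weights once before
# inference avoids the per-call cuDNN warning.
flatten_jk_lstm(model)
with torch.no_grad():
    for data in loader:
        data = data.to(device)
        output = model(data)
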
Environment

erikhuck commented 1 year ago

And here's the code for instantiating the model:

model = GraphAttentionNetwork(
    device=device, n_atom_features=dataset.n_atom_features, n_bond_features=dataset.n_bond_features, gat_hidden_channels=64,
    gat_n_layers=4, gat_out_channels=256, gat_act_func='relu', gat_norm='batch_norm', aggregation='max-add-mean',
    gat_dropout=0.1, n_heads=10, leaky_relu_slope=0.2, mlp_hidden_channels=128, mlp_n_layers=3, mlp_act_func='relu',
    mlp_norm='batch_norm', mlp_dropout=0.1)