HarryShomer / LPFormer

Implementation of the KDD'24 paper "LPFormer: An Adaptive Graph Transformer for Link Prediction"
https://arxiv.org/abs/2310.11009

device setting issue #2

Closed: boxorange closed this 3 months ago

boxorange commented 3 months ago

Hi,

When I set the device to any number other than 0 (e.g., --device 1 on a machine with 4 GPUs, where valid IDs are 0-3), I get the error below. Do you have any thoughts on what might be causing it? Does some part of the code need to be modified?

Traceback (most recent call last):
  File "/home/ac.gpark/LPFormer/src/run.py", line 328, in <module>
    main()
  File "/home/ac.gpark/LPFormer/src/run.py", line 324, in main
    run_model(args)
  File "/home/ac.gpark/LPFormer/src/run.py", line 243, in run_model
    train_data(cmd_args, args, data, device, verbose = not cmd_args.non_verbose)
  File "/home/ac.gpark/LPFormer/src/train/train_model.py", line 197, in train_data
    best_valid = train_loop(args, train_args, data, device, loggers, seed, run_save_name, verbose)
  File "/home/ac.gpark/LPFormer/src/train/train_model.py", line 111, in train_loop
    loss = train_epoch(model, score_func, data, optimizer, args, device)
  File "/home/ac.gpark/LPFormer/src/train/train_model.py", line 59, in train_epoch
    h = model(edges, adj_prop=masked_adjt, adj_mask=masked_adj)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac.gpark/LPFormer/src/models/link_transformer.py", line 116, in forward
    X_node = self.propagate(adj_prop, test_set)
  File "/home/ac.gpark/LPFormer/src/models/link_transformer.py", line 142, in propagate
    X_node = self.node_encoder(x, adj, test_set)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac.gpark/LPFormer/src/modules/node_encoder.py", line 42, in forward
    X_gnn = self.gnn_encoder(features, adj_t)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac.gpark/LPFormer/src/models/other_models.py", line 66, in forward
    xi = conv(x, adj_t)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 233, in forward
    edge_index = gcn_norm(  # yapf: disable
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 56, in gcn_norm
    adj_t = torch_sparse.fill_diag(adj_t, fill_value)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch_sparse/diag.py", line 92, in fill_diag
    return set_diag(src, value.new_full(sizes, fill_value), k)
  File "/home/ac.gpark/anaconda3/envs/ness/lib/python3.9/site-packages/torch_sparse/diag.py", line 49, in set_diag
    new_row[mask] = row
RuntimeError: shape mismatch: value tensor of shape [2483110] cannot be broadcast to indexing result of shape [0]
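For reference, the failing assignment in torch_sparse's set_diag is boolean-mask indexing where the mask ends up with no True entries, so the indexing result has shape [0] and cannot receive the value tensor. A minimal standalone reproduction of that error class (plain PyTorch, not LPFormer code) looks like this:

```python
import torch

new_row = torch.zeros(10, dtype=torch.long)
mask = torch.zeros(10, dtype=torch.bool)  # no True entries -> indexing result has shape [0]
row = torch.arange(5)                     # value tensor of shape [5]

try:
    # Same failure mode as torch_sparse/diag.py: new_row[mask] = row
    new_row[mask] = row
except RuntimeError as e:
    print("RuntimeError:", e)
```

An all-False mask like this typically means some intermediate tensor was built on a different device than expected (e.g., left on cuda:0 while the rest of the data moved to cuda:1), so the diagonal-position comparison never matches.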
HarryShomer commented 3 months ago

Closing since it's mentioned in #1. We'll deal with it there.
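As a general workaround for device-placement bugs like this (not specific to LPFormer), you can restrict the process to a single physical GPU via CUDA_VISIBLE_DEVICES, so the code's hard-coded device 0 maps to the card you actually want:

```shell
# Hide all GPUs except physical GPU 1 from the process; PyTorch then
# enumerates it as cuda:0, so --device 0 targets the intended card.
export CUDA_VISIBLE_DEVICES=1
python3 -c "import os; print(os.environ.get('CUDA_VISIBLE_DEVICES'))"
```

With that set, running the script with --device 0 should use physical GPU 1, sidestepping any tensors that are implicitly created on device 0.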