rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License
643 stars, 114 forks

error when running ogbg-ppa #20

Closed: Chen-Cai-OSU closed this issue 1 year ago

Chen-Cai-OSU commented 1 year ago

Hello, thank you very much for the great code. It is super well written and very modular, making it easy to extend. I encountered an error when running ogbg-ppa with

python main.py --cfg configs/GPS/ogbg-ppa-GPS.yaml wandb.use False gt.layer_type CustomGatedGCN+Performer

I got many lines of

accuracy() missing 1 required positional argument: 'task'

followed by

  File "/code/GraphGPS/main.py", line 174, in <module>
    train_dict[cfg.train.mode](loggers, loaders, model, optimizer,
  File "/code/GraphGPS/graphgps/train/custom_train.py", line 122, in custom_train
    perf[0].append(loggers[0].write_epoch(cur_epoch))
  File "/code/GraphGPS/graphgps/logger.py", line 245, in write_epoch
    task_stats = self.classification_multilabel()
  File "/code/GraphGPS/graphgps/logger.py", line 146, in classification_multilabel
    'accuracy': reformat(acc(pred_score, true)),
  File "/code/GraphGPS/graphgps/metric_wrapper.py", line 323, in __call__
    return self.compute(preds, target)
  File "/code/GraphGPS/graphgps/metric_wrapper.py", line 310, in compute
    metric_val = torch.nanmean(torch.stack(metric_val))  # PyTorch1.10
RuntimeError: stack expects a non-empty TensorList

I was wondering whether you could identify the source of the error. I am using torchmetrics 0.11.0 and torch 1.10.2. Let me know if you need any further information. Thank you!
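The two errors are likely connected. A minimal sketch of the pattern the traceback suggests, assuming (an inference, not the actual GraphGPS code) that the metric wrapper evaluates the metric once per label column, prints and skips any column that raises, and then averages what is left: if every column fails with the same `TypeError`, the list to aggregate is empty, and the aggregation step fails with an unrelated-looking error. The names `per_column_mean`, `strict`, and `accuracy_requiring_task` are hypothetical; plain Python floats stand in for tensors, and NaN filtering stands in for `torch.nanmean`.

```python
import math
from typing import Callable, List, Sequence


def per_column_mean(
    metric: Callable[[Sequence[float], Sequence[int]], float],
    preds: List[List[float]],
    target: List[List[int]],
    strict: bool = False,
) -> float:
    """Evaluate `metric` per label column and average the results, ignoring NaNs.

    Hypothetical sketch: failing columns are printed and skipped, which would
    produce one 'accuracy() missing ...' line per column, as in the issue.
    """
    values: List[float] = []
    errors: List[Exception] = []
    for col in range(len(preds[0])):
        col_preds = [row[col] for row in preds]
        col_target = [row[col] for row in target]
        try:
            values.append(metric(col_preds, col_target))
        except Exception as exc:  # print and skip the failing column
            print(exc)
            errors.append(exc)
    if not values:
        if strict and errors:
            # Surface the real cause instead of a confusing aggregation error.
            raise errors[0]
        # Analogous to torch.stack([]) raising "stack expects a non-empty TensorList".
        raise RuntimeError("no metric values to aggregate")
    finite = [v for v in values if not math.isnan(v)]
    return sum(finite) / len(finite) if finite else math.nan


def accuracy_requiring_task(preds, target, task):
    """Hypothetical stand-in for a metric whose `task` argument is required."""
    correct = sum(int((p >= 0.5) == bool(t)) for p, t in zip(preds, target))
    return correct / len(target)


preds = [[0.9, 0.2], [0.1, 0.8]]
target = [[1, 0], [0, 1]]
# Called without `task`, every column raises TypeError and is skipped,
# so the aggregation step is what finally fails.
try:
    per_column_mean(accuracy_requiring_task, preds, target)
except RuntimeError as err:
    print("aggregation failed:", err)
```

With `strict=True` the sketch re-raises the first per-column `TypeError`, which points directly at the missing `task` argument rather than at the empty stack.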

rampasek commented 1 year ago

Hi!

This is caused by an API change in torchmetrics starting with version 0.11. You can either downgrade torchmetrics in your environment to 0.10.x or move to the updated version of GraphGPS (see PR #22).
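If neither option is convenient, one way to keep a single call site working under both conventions is to pass `task` only when the metric function accepts it. A minimal sketch, assuming the relevant change is the newly required `task` argument (in torchmetrics 0.11, `torchmetrics.functional.accuracy` takes `task` as `"binary"`, `"multiclass"`, or `"multilabel"`, and the multilabel case also needs `num_labels`). The helper `call_with_task` and the stand-ins `old_accuracy` / `new_accuracy` are hypothetical so the example runs without torchmetrics installed; this is not taken from PR #22.

```python
import inspect
from typing import Any, Callable


def call_with_task(fn: Callable[..., Any], *args: Any, task: str, **kwargs: Any) -> Any:
    """Call `fn`, forwarding `task` only if its signature declares a `task` parameter.

    Hypothetical helper: lets one call site work with metric functions written
    against either the pre-0.11 or the 0.11+ convention.
    """
    if "task" in inspect.signature(fn).parameters:
        kwargs["task"] = task
    return fn(*args, **kwargs)


def old_accuracy(preds, target, threshold=0.5):
    """Stand-in for a pre-0.11-style accuracy with no `task` argument."""
    correct = sum(int((p >= threshold) == bool(t)) for p, t in zip(preds, target))
    return correct / len(target)


def new_accuracy(preds, target, task, threshold=0.5):
    """Stand-in for a 0.11-style accuracy where `task` is required."""
    if task != "binary":
        raise ValueError(f"this stand-in only supports task='binary', got {task!r}")
    return old_accuracy(preds, target, threshold)


preds, target = [0.9, 0.4, 0.7], [1, 0, 0]
# Both calls go through the same helper and return the same value.
print(call_with_task(old_accuracy, preds, target, task="binary"))
print(call_with_task(new_accuracy, preds, target, task="binary"))
```

Inspecting the signature avoids parsing version strings, but it only papers over this one argument; pinning torchmetrics to 0.10.x or updating GraphGPS, as suggested above, remains the simpler route.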

Best, Ladislav