jchengai / forecast-mae

[ICCV'2023] Forecast-MAE: Self-supervised Pre-training for Motion Forecasting with Masked Autoencoders
https://arxiv.org/pdf/2308.09882.pdf

Multi-agent checkpoint/scores problem #16

[Open] caiocj1 opened this issue 3 months ago

caiocj1 commented 3 months ago

Hello! Thank you for your work.

Just a question: when I try to evaluate the multi-agent baseline, I get the following error:

  File "/shared/home/sf08116/workspace/forecast-mae/eval.py", line 19, in main
    model = Model.load_from_checkpoint(checkpoint)
  File "/shared/home/sf08116/workspace/forecast-mae/eval.py", line 39, in <module>
    main()
RuntimeError: Error(s) in loading state_dict for Trainer:
    Unexpected key(s) in state_dict: "net.hist_embed.levels.0.blocks.0.attn.rpb", "net.hist_embed.levels.0.blocks.1.attn.rpb", "net.hist_embed.levels.1.blocks.0.attn.rpb", "net.hist_embed.levels.1.blocks.1.attn.rpb", "net.hist_embed.levels.2.blocks.0.attn.rpb", "net.hist_embed.levels.2.blocks.1.attn.rpb". 
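To see exactly which parameters differ, one can compare the checkpoint's keys with the model's own `state_dict()` keys. The sketch below is a plain-Python helper; the name `diff_keys` is hypothetical and not part of forecast-mae. It reports the same two categories that strict loading checks, and it assumes a PyTorch Lightning checkpoint, which keeps the weights under its `"state_dict"` entry (the `net.` prefix in the error suggests the network lives in a `net` attribute of the Lightning module).

```python
from typing import Iterable, Set, Tuple


def diff_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Return (unexpected, missing) parameter names, as strict loading would report them."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    unexpected = ckpt - model  # stored in the checkpoint, but the model has no such parameter
    missing = model - ckpt     # required by the model, but absent from the checkpoint
    return unexpected, missing


# Hypothetical usage with a real checkpoint (the path and `model` are placeholders):
#   import torch
#   ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
#   unexpected, missing = diff_keys(ckpt["state_dict"].keys(), model.state_dict().keys())

if __name__ == "__main__":
    # Illustrative key names only; the second one mirrors the keys in the error above.
    ckpt_keys = ["net.hist_embed.levels.0.blocks.0.attn.proj.weight",
                 "net.hist_embed.levels.0.blocks.0.attn.rpb"]
    model_keys = ["net.hist_embed.levels.0.blocks.0.attn.proj.weight"]
    print(diff_keys(ckpt_keys, model_keys))
```

An empty `missing` set would indicate that only the extra `rpb` entries are the problem.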

Passing strict=False when loading the checkpoint bypasses the error, but the metrics are then slightly higher (i.e., worse) than the reported ones:
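Note that strict=False suppresses errors about missing keys as well as unexpected ones, so on its own it cannot confirm that every parameter (for example, the mode-scoring head) was actually restored. A narrower workaround, sketched below under the assumption that the six `rpb` entries are the only mismatch, is to drop exactly those keys and keep strict loading on. The regex and the name `drop_rpb_keys` are hypothetical helpers, not part of forecast-mae.

```python
import re
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

# Matches the keys listed in the error, e.g. "net.hist_embed.levels.0.blocks.1.attn.rpb".
RPB_KEY = re.compile(r"^net\.hist_embed\.levels\.\d+\.blocks\.\d+\.attn\.rpb$")


def drop_rpb_keys(state_dict: Mapping[str, V]) -> Dict[str, V]:
    """Return a copy of `state_dict` without the attention `rpb` entries."""
    return {k: v for k, v in state_dict.items() if not RPB_KEY.match(k)}


# Hypothetical usage (the checkpoint path and how `model` is constructed are assumptions):
#   import torch
#   ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
#   model.load_state_dict(drop_rpb_keys(ckpt["state_dict"]), strict=True)
#   # strict=True still raises if any parameter the model needs is missing.
```

Alternatively, `torch.nn.Module.load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys`; checking that `missing_keys` is empty gives the same assurance.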

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric      ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val/cls_loss        │    1.7821619510650635     │
│         val/loss          │    2.1131792068481445     │
│       val/reg_loss        │    0.33101630210876465    │
│        val_ActorMR        │    0.1966301053762436     │
│       val_AvgMinADE       │     0.733747661113739     │
│       val_AvgMinFDE       │    1.6657145023345947     │
└───────────────────────────┴───────────────────────────┘
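The numbers in the table are consistent with val/loss being the sum of the classification and regression terms (the two differ by only about 1e-6, within float32 rounding). This is an inference from the table, not something checked against the code. It also shows that the classification term makes up roughly 84% of the total:

```python
# Values copied from the validation table above.
cls_loss = 1.7821619510650635
reg_loss = 0.33101630210876465
total_loss = 2.1131792068481445

# Assumption: val/loss = cls_loss + reg_loss; they agree up to float32 rounding.
assert abs((cls_loss + reg_loss) - total_loss) < 1e-5

share = cls_loss / total_loss
print(f"classification share of val/loss: {share:.1%}")
```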

However, the main problem is the scoring: the best score is always assigned to modality 1 (hence the high classification loss). Is this caused by strict=False? Do you also always get modality 1 as the top-scored mode?
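One way to put a number on this: assuming the classification loss is a cross-entropy over K = 6 predicted modes (six is what the Argoverse 2 benchmark evaluates; the exact setting in forecast-mae is an assumption here), a uniform prediction gives exactly ln 6 ≈ 1.79 regardless of the target, very close to the 1.78 in the table. The sketch below computes that baseline and counts how often each mode receives the top score; `NUM_MODES`, `top_mode_counts`, and the nested-list score format are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence

NUM_MODES = 6  # assumption: six predicted modes per agent, as in the Argoverse 2 benchmark

# Cross-entropy of a uniform distribution over K modes is ln(K), whatever the target mode is.
uniform_ce = math.log(NUM_MODES)
print(f"uniform baseline: {uniform_ce:.4f} vs. reported val/cls_loss: 1.7822")


def top_mode_counts(scores: Sequence[Sequence[float]]) -> List[int]:
    """Count, for each mode index (0-based), how many samples give it the highest score."""
    # Each row holds NUM_MODES scores; on ties, max() keeps the lowest index.
    counts = Counter(max(range(len(row)), key=row.__getitem__) for row in scores)
    return [counts.get(m, 0) for m in range(NUM_MODES)]
```

If one index collects nearly all the counts while the loss stays near ln 6, the scores carry little information; one possible cause is a scoring head whose weights were not restored, which strict=False would hide.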