EGO4D / episodic-memory

MIT License

Bug report: cannot load checkpoints for EgoTracks #51

Open: vineetparikh opened this issue 10 months ago

vineetparikh commented 10 months ago

Is there any demo code for loading the provided checkpoints into a STARK tracker? I'd like to try the EgoSTARK setup on a different egocentric dataset to see how well it performs.

vineetparikh commented 9 months ago

Following up on this: it looks like loading the provided checkpoints actually fails because of an import error.

import torch
from argparse import Namespace
from tracking.utils.defaults import setup
import tracking.models.stark_tracker.stark_tracker as stark

stark_args = Namespace(model_type="STARK", config_file="egotracks_models/stark_st_base.yaml", opts=[])
stark_cfg = setup(stark_args)
tracker = stark.STARKTracker(stark_cfg)

works fine with the provided base config, albeit with some extra parameters (I get messages that the checkpoint contains additional keys, covering a lot of layer4 keys along with the two fc keys). But as soon as I try to run

with open("/path/to/STARKST_ep0001.pth.tar", "rb") as f:
    tracker.model.load_state_dict(torch.load(f)["state_dict"], strict=True)

I get the following error:

... 
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/egotracks-1.0-py3.9.egg/tracking/tools/trainers/base_trainer.py", line 16, in <module>
    from tracking.dataset.build import build_dataloaders
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/egotracks-1.0-py3.9.egg/tracking/dataset/build.py", line 12, in <module>
    from tracking.dataset.train_datasets.ego4d_vq import Ego4DVQ
  File "/home/vap43/.conda/envs/hos/lib/python3.9/site-packages/egotracks-1.0-py3.9.egg/tracking/dataset/train_datasets/ego4d_vq.py", line 8, in <module>
    from tracking.dataset.ego4d_tracking import Ego4DTracking
ModuleNotFoundError: No module named 'tracking.dataset.ego4d_tracking'

which seems to indicate that there's another ego4d_tracking dataset file that isn't public, and/or that a file was renamed without updating the imports that reference it. Any insight into what the right file should be?
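Until the missing module is published, one possible workaround is to make unpickling tolerant of classes whose module cannot be imported. The traceback shows `torch.load` importing `tracking.tools.trainers.base_trainer` from inside `find_class`, which suggests the checkpoint pickles a reference to a trainer object next to `state_dict`. This is an untested sketch, not the maintainers' fix: `TolerantUnpickler`, `MissingObject`, `tolerant_loads`, and `load_checkpoint_tolerant` are hypothetical names, and the PyTorch helper assumes your `torch.load` accepts a `pickle_module` exposing `Unpickler` and `load`, plus a `weights_only` argument.

```python
import io
import pickle
import types


class MissingObject:
    """Hypothetical stand-in for objects whose class could not be imported."""

    def __init__(self, *args, **kwargs):
        # Only called when the pickle rebuilds the object via cls(*args).
        self.args = args
        self.kwargs = kwargs

    def __setstate__(self, state):
        # Keep whatever state was pickled so it can still be inspected.
        self.state = state


class TolerantUnpickler(pickle.Unpickler):
    """Unpickler that substitutes a stub class when a pickled class fails to import."""

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except (ImportError, AttributeError):
            # One stub subclass per missing class, named after the original.
            return type(name, (MissingObject,), {"__module__": module})


def tolerant_loads(data):
    """Unpickle bytes, replacing unimportable classes with MissingObject stubs."""
    return TolerantUnpickler(io.BytesIO(data)).load()


def load_checkpoint_tolerant(path):
    """Load a checkpoint with torch.load while tolerating missing modules (untested sketch)."""
    import torch  # imported lazily so the pure-pickle parts run without PyTorch

    # torch.load reads pickle_module.__name__, so use a real module object as the shim.
    shim = types.ModuleType("tolerant_pickle")
    shim.Unpickler = TolerantUnpickler
    shim.load = lambda f, **kwargs: TolerantUnpickler(f, **kwargs).load()
    # Assumption: weights_only must be False when a custom pickle_module is given;
    # older PyTorch releases do not accept this argument, so remove it there.
    return torch.load(path, map_location="cpu", pickle_module=shim, weights_only=False)
```

Assuming the `tracker` built in the snippet above, `ckpt = load_checkpoint_tolerant("/path/to/STARKST_ep0001.pth.tar")` followed by `tracker.model.load_state_dict(ckpt["state_dict"], strict=False)` would load the weights while skipping the unimportable trainer class. With `strict=False`, `load_state_dict` returns the missing and unexpected keys instead of raising, so the layer4/fc mismatches can be inspected.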

shengyuhao commented 8 months ago

I'm also hitting this error and hope the authors can provide more details or an explanation. Thanks!

relh commented 7 months ago

It looks like the only other file with a similar function signature that also has a response_tracks variable is this one: https://github.com/EGO4D/episodic-memory/blob/5275b9570a3d77884e8685bad474309f7778db1f/EgoTracks/tracking/dataset/eval_datasets/ego4d_lt_tracking_dataset.py#L45
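Building on that pointer, and without assuming which class `ego4d_lt_tracking_dataset.py` defines, another option is to register a placeholder module under the missing name before calling `torch.load`, so that the `from tracking.dataset.ego4d_tracking import Ego4DTracking` line in the traceback succeeds. This is a hypothetical sketch: `register_placeholder_module` is an illustrative helper, and it assumes restoring the weights does not need any real behavior from `Ego4DTracking`. If `ego4d_vq.py` uses that class for more than defining names at import time, a bare placeholder will not be enough.

```python
import sys
import types


def register_placeholder_module(module_name, class_names):
    """Insert an empty module into sys.modules exposing placeholder classes.

    Hypothetical helper: lets `from <module_name> import <ClassName>` succeed
    when the real module is missing. It restores no real behavior.
    """
    if module_name in sys.modules:
        # Never shadow a module that is actually importable/loaded.
        return sys.modules[module_name]
    module = types.ModuleType(module_name)
    for class_name in class_names:
        # Plain classes, so importing code can still subclass them.
        placeholder = type(class_name, (object,), {"__module__": module_name})
        setattr(module, class_name, placeholder)
    sys.modules[module_name] = module
    return module


# Assumed usage, run before torch.load(...) in the snippet above:
# register_placeholder_module("tracking.dataset.ego4d_tracking", ["Ego4DTracking"])
```

The helper refuses to replace an existing `sys.modules` entry, so it cannot mask the real module if one is published later, and it only affects the current Python process.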

SharkyK commented 5 months ago

I also have this problem. If anyone has solved this error, could you please share how? Thanks!