lliuz / ARFlow

The official PyTorch implementation of the paper "Learning by Analogy: Reliable Supervision from Transformations for Unsupervised Optical Flow Estimation".
MIT License

What is the correspondence between config files and checkpoints? #27

Open Kewenjing1020 opened 3 years ago

Kewenjing1020 commented 3 years ago

Hi, I noticed there are several config files and checkpoints for each dataset. Taking KITTI as an example: what is the correspondence between the config files `kitti15_ft_ar.json`, `kitti15_ft.json`, and `kitti_raw.json` and the checkpoints `pwclite_ar_mv.tar`, `pwclite_ar.tar`, and `pwclite_raw.tar`? Which config should I use to reproduce each of these three models?

Besides, I tried to evaluate the checkpoint `pwclite_ar_mv.tar` with each of the three config files and always got the following error:

```
[INFO] => using pre-trained weights checkpoints/KITTI15/pwclite_ar_mv.tar.
Traceback (most recent call last):
  File "train.py", line 50, in <module>
    basic_train.main(cfg, _log)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/basic_train.py", line 53, in main
    train_loader, valid_loader, model, loss, _log, cfg.save_root, cfg.train)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/kitti_trainer.py", line 13, in __init__
    train_loader, valid_loader, model, loss_func, _log, save_root, config)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/base_trainer.py", line 26, in __init__
    self.model = self._init_model(model)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/base_trainer.py", line 75, in _init_model
    model.load_state_dict(weights)
  File "/scratch/workspace/wenjingk/anaconda-3.6/envs/python3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for PWCLite:
        size mismatch for flow_estimators.conv1.0.weight: copying a param with shape torch.Size([128, 198, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 115, 3, 3]).
        size mismatch for context_networks.convs.0.0.weight: copying a param with shape torch.Size([128, 68, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 34, 3, 3]).
```

TonyLianLong commented 1 year ago

I guess you need to set 3 frames in the config for the multi-view (`_mv`) checkpoint; a 3-frame input changes the layer shapes, which would explain the size mismatches.
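One hedged way to check which frame count a checkpoint expects, before wiring it into a config, is to inspect the input-channel dimension of `flow_estimators.conv1.0.weight` in the saved state dict. This is only a sketch: the 115/198 channel counts are taken from the traceback above, and the `"state_dict"` key in the usage comment is an assumption about how the `.tar` checkpoint is packed.

```python
# Sketch: guess the frame count a PWCLite checkpoint was trained with
# from the first flow-estimator conv layer's input channels.
# The 115 (2-frame) and 198 (3-frame) counts come from the size-mismatch
# error in this thread; treat them as assumptions, not an official API.
FRAME_COUNT_BY_IN_CHANNELS = {115: 2, 198: 3}

def infer_n_frames(state_dict):
    """Return the presumed number of input frames (2 or 3) for a
    PWCLite state dict, based on conv1's input-channel count."""
    weight = state_dict["flow_estimators.conv1.0.weight"]
    in_channels = weight.shape[1]  # conv weight layout: [out_ch, in_ch, kH, kW]
    try:
        return FRAME_COUNT_BY_IN_CHANNELS[in_channels]
    except KeyError:
        raise ValueError(f"unexpected input channel count: {in_channels}")

# Hypothetical usage (checkpoint layout assumed):
#   import torch
#   ckpt = torch.load("checkpoints/KITTI15/pwclite_ar_mv.tar", map_location="cpu")
#   infer_n_frames(ckpt["state_dict"])  # 3 would indicate a multi-view model
```

If this returns 3 for `pwclite_ar_mv.tar`, the config used with it should be the one that sets a 3-frame input, and the 2-frame configs should be paired with the other two checkpoints.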