xingyizhou / CenterTrack

Simultaneous object detection and tracking using center points.
MIT License

How to include validation in training process on KITTI dataset? #196

Open zhangchuang-zc opened 3 years ago

zhangchuang-zc commented 3 years ago

Hi! Nice work! I want to include validation in the training process on the KITTI dataset, so I ran main.py with the following command:

    python main.py tracking --exp_id kitti_half --dataset kitti_tracking \
      --dataset_version train_half --pre_hm --same_aug --hm_disturb 0.05 \
      --lost_disturb 0.2 --fp_disturb 0.1 --gpus 0 --batch_size 16 \
      --trainval --eval_val --load_model ../models/nuScenes_3Ddetection_e140.pth

But I got the following error:

      File "main.py", line 101, in <module>
        main(opt)
      File "main.py", line 79, in main
        log_dict_val, preds = trainer.val(epoch, val_loader)
      File "CenterTrack/src/lib/trainer.py", line 314, in val
        return self.run_epoch('val', epoch, data_loader)
      File "CenterTrack/src/lib/trainer.py", line 141, in run_epoch
        for iter_id, batch in enumerate(data_loader):
      File "lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
        data = self._next_data()
      File "lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
        return self._process_data(data)
      File "lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
        data.reraise()
      File "lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "CenterTrack/src/lib/dataset/generic_dataset.py", line 114, in __getitem__
        c, s, width, height, disturb=True)
      File "CenterTrack/src/lib/dataset/generic_dataset.py", line 276, in _get_aug_param
        c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
    TypeError: can't multiply sequence by non-int of type 'numpy.float64'
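The last line of the traceback says the left operand of `*` is a sequence, so `s` must already be a list when line 276 runs. A minimal sketch that replays only that expression is shown here. `reproduce_shift_error` is a hypothetical helper, and `512.0` and `0.1` are arbitrary stand-ins for CenterTrack's `s` and `cf`; it assumes `s` was wrapped as `[s, s]` beforehand.

```python
import numpy as np


def reproduce_shift_error(scale=512.0, cf=0.1):
    """Replay the failing expression from _get_aug_param in isolation.

    Hypothetical stand-ins: `scale` plays the role of CenterTrack's `s`
    and `cf` the shift factor; both default values are arbitrary.
    Returns the TypeError message, or None if no error is raised.
    """
    s = [scale, scale]  # the scale has become a Python list (a sequence)
    # np.clip on Python scalars returns a numpy.float64, not an int
    offset = np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
    try:
        _ = s * offset  # a list can only be repeated by an integer
    except TypeError as err:
        return str(err)
    return None


print(reproduce_shift_error())
```

The exact wording of the message may vary with the NumPy version. The point is that multiplying a list by a float-valued NumPy scalar raises a `TypeError` instead of scaling each element.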

Is it possible to train a model on the KITTI dataset with validation included? @xingyizhou Thanks!

ocetintas commented 3 years ago

I had the same issue. I must admit that I didn't analyze the function in depth, but the following "hacky" change to `_get_aug_param` in `src/lib/dataset/generic_dataset.py` solved the problem for me:

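      if type(s) == float:
        # s is a plain Python float here: expand it to per-axis values and
        # index into the list instead of multiplying the list itself
        s = [s, s]
        c[0] += s[0] * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        c[1] += s[1] * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
      else:
        # any other scale type is multiplied directly, as in the original line
        c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
      aug_s = np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)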
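The branch above can also be written without a type check by broadcasting the scale to one value per axis. This is a sketch, not code from the repository. `disturb_center` is a hypothetical name, and support for a per-axis `(sx, sy)` scale is an assumption, since the traceback only shows the scalar case. It also draws from a `numpy.random.Generator` instead of the global `np.random.randn`, so results are reproducible.

```python
import numpy as np


def disturb_center(c, s, cf, rng=None):
    """Hypothetical standalone version of the center-shift step.

    c  : length-2 float array (x, y), modified in place and returned
    s  : scale, a scalar (float or numpy.float64) or, by assumption,
         a per-axis (sx, sy) pair
    cf : shift factor; each random offset is clipped to [-2*cf, 2*cf]
    """
    rng = np.random.default_rng() if rng is None else rng
    # Broadcast a scalar scale to (sx, sy); a length-2 scale passes through.
    s_xy = np.broadcast_to(np.asarray(s, dtype=np.float64), (2,))
    for axis in range(2):
        offset = np.clip(rng.standard_normal() * cf, -2 * cf, 2 * cf)
        c[axis] += s_xy[axis] * offset
    return c


rng = np.random.default_rng(0)
for scale in (512.0, np.float64(512.0), np.array([640.0, 480.0])):
    center = np.array([320.0, 240.0])
    disturbed = disturb_center(center.copy(), scale, cf=0.1, rng=rng)
    print(type(scale).__name__, disturbed)
```

Broadcasting also avoids the `type(s) == float` test, which compares exact types. That test is True for a Python `float` but False for `numpy.float64`, even though `numpy.float64` is a subclass of `float`.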