z-x-yang / Segment-and-Track-Anything

An open-source project dedicated to tracking and segmenting any objects in videos, either automatically or interactively. The primary algorithms utilized include the Segment Anything Model (SAM) for key-frame segmentation and Associating Objects with Transformers (AOT) for efficient tracking and propagation purposes.
GNU Affero General Public License v3.0
2.75k stars, 332 forks

How to train aot? Where is aot's train.py? #135

Open: 22236 opened this issue 5 months ago

22236 commented 5 months ago

I want to train AOT before moving on to DeAOT.

22236 commented 5 months ago

Use GPU 0 for training VOS.
Build VOS model.
Use Frozen BN in Encoder!
Build optimizer.
Total Param: 5.73M
Process dataset...
Video Num: 29 X 1
<class 'numpy.ndarray'>
(512, 512, 3)
(512, 512, 3)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Video Num: 29 X 1
Done!
Remove ['features.0.1.num_batches_tracked', 'features.1.conv.0.1.num_batches_tracked', 'features.1.conv.2.num_batches_tracked', 'features.2.conv.0.1.num_batches_tracked', 'features.2.conv.1.1.num_batches_tracked', 'features.2.conv.3.num_batches_tracked', 'features.3.conv.0.1.num_batches_tracked', 'features.3.conv.1.1.num_batches_tracked', 'features.3.conv.3.num_batches_tracked', 'features.4.conv.0.1.num_batches_tracked', 'features.4.conv.1.1.num_batches_tracked', 'features.4.conv.3.num_batches_tracked', 'features.5.conv.0.1.num_batches_tracked', 'features.5.conv.1.1.num_batches_tracked', 'features.5.conv.3.num_batches_tracked', 'features.6.conv.0.1.num_batches_tracked', 'features.6.conv.1.1.num_batches_tracked', 'features.6.conv.3.num_batches_tracked', 'features.7.conv.0.1.num_batches_tracked', 'features.7.conv.1.1.num_batches_tracked', 'features.7.conv.3.num_batches_tracked', 'features.8.conv.0.1.num_batches_tracked', 'features.8.conv.1.1.num_batches_tracked', 'features.8.conv.3.num_batches_tracked', 'features.9.conv.0.1.num_batches_tracked', 'features.9.conv.1.1.num_batches_tracked', 'features.9.conv.3.num_batches_tracked', 'features.10.conv.0.1.num_batches_tracked', 'features.10.conv.1.1.num_batches_tracked', 'features.10.conv.3.num_batches_tracked', 'features.11.conv.0.1.num_batches_tracked', 'features.11.conv.1.1.num_batches_tracked', 'features.11.conv.3.num_batches_tracked', 'features.12.conv.0.1.num_batches_tracked', 'features.12.conv.1.1.num_batches_tracked', 'features.12.conv.3.num_batches_tracked', 'features.13.conv.0.1.num_batches_tracked', 'features.13.conv.1.1.num_batches_tracked', 'features.13.conv.3.num_batches_tracked', 'features.14.conv.0.1.num_batches_tracked', 'features.14.conv.1.1.num_batches_tracked', 'features.14.conv.3.num_batches_tracked', 'features.15.conv.0.1.num_batches_tracked', 'features.15.conv.1.1.num_batches_tracked', 'features.15.conv.3.num_batches_tracked', 'features.16.conv.0.1.num_batches_tracked', 'features.16.conv.1.1.num_batches_tracked', 'features.16.conv.3.num_batches_tracked', 'features.17.conv.0.1.num_batches_tracked', 'features.17.conv.1.1.num_batches_tracked', 'features.17.conv.3.num_batches_tracked', 'features.18.1.num_batches_tracked', 'classifier.1.weight', 'classifier.1.bias'] from pretrained model.
Load pretrained backbone model from Segment-and-Track-Anything-main/aot/pretrain_models/mobilenet_v2-b0353104.pth.
Start training:
Traceback (most recent call last):
  File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 111, in <module>
    main()  # run the main function
  File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 108, in main
    main_worker(0, cfg, args.amp)  # single-process training
  File "Segment-and-Track-Anything-main/my_code/aot/train.py", line 52, in main_worker
    trainer.sequential_training()
  File "Segment-and-Track-Anything-main/./aot/networks/managers/trainer.py", line 419, in sequential_training
    for sample in enumerate(train_loader):
  File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/anaconda3/envs/samtr/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 308, in __getitem__
    return self.datasets[dataset_idx][sample_idx]
KeyError: 4
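For context on where `KeyError: 4` comes from: the deepest frame is PyTorch's `ConcatDataset.__getitem__`, which maps a global sample index to a (sub-dataset, local index) pair and then subscripts the chosen sub-dataset with that local integer. Since no deeper frame appears in the traceback, the exception is probably raised by that subscript itself, which is what would happen if one of the concatenated objects were a plain `dict` (for example, keyed by video name) instead of a dataset whose `__getitem__` accepts integers `0..len-1`. This is an inference from the traceback, not something confirmed in the issue. The pure-Python sketch below mirrors the index mapping (cumulative sizes plus `bisect_right`, ignoring negative indices) and reproduces the error. `MiniConcat` and `videos` are hypothetical names for illustration only. The number in the message is the local index, so it changes with whichever sample the sampler happens to draw.

```python
import bisect
from itertools import accumulate


class MiniConcat:
    """Simplified, hypothetical stand-in for ConcatDataset's index mapping."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # cumulative_sizes[i] = total number of samples in datasets[0..i]
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Find which sub-dataset the global index falls into
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # Convert the global index into an index local to that sub-dataset
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Same subscript as the failing line in the traceback
        return self.datasets[dataset_idx][sample_idx]


# Hypothetical: a dict keyed by video name has a len() but no integer keys
videos = {"cat": "frames_cat", "dog": "frames_dog", "car": "frames_car"}
ok = MiniConcat([["a", "b"], ["c", "d", "e"]])
bad = MiniConcat([["a", "b"], videos])

print(ok[3])  # global index 3 -> datasets[1][1] -> d
try:
    bad[4]  # global index 4 -> datasets[1][2] -> videos[2]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 2
```

If this is the cause, storing the video names in a list and looking them up by integer position inside a proper `__getitem__` avoids the failure.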

The above is the error I get, and I don't know how to fix it. When training on my own dataset I get a KeyError, and I don't understand what it means. When I change the len or obj parameters in default.py, the reported key changes: sometimes it is KeyError: 6, sometimes 0, 2, or 4. What do parameters such as idx, max_obj_n, and seq_len refer to in the code?
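As a rough guide to the three parameters asked about, here is a reading of how video object segmentation training datasets of this kind commonly use them. It is an assumption, not a description of the repository's actual code. `idx` is the integer index the `DataLoader` sampler passes to `__getitem__` and selects one training video; `seq_len` is how many frames are sampled from that video to form one training clip (a reference frame plus the following frames); `max_obj_n` caps how many object ids per clip are kept. `ToyVOSTrain` and its arguments are hypothetical and exist only to show those roles. It keeps the video names in a list so that any integer `idx` maps to a valid video.

```python
import random


class ToyVOSTrain:
    """Hypothetical VOS training dataset illustrating idx, seq_len and max_obj_n."""

    def __init__(self, videos, seq_len=5, max_obj_n=10, seed=0):
        # videos: dict mapping video name -> list of per-frame object-id sets
        self.videos = videos
        self.seq_names = sorted(videos)  # integer position -> video name
        self.seq_len = seq_len
        self.max_obj_n = max_obj_n
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.seq_names)

    def __getitem__(self, idx):
        # idx selects a video; wrap around so any integer maps to a valid one
        name = self.seq_names[idx % len(self.seq_names)]
        frames = self.videos[name]
        if len(frames) < self.seq_len:
            raise ValueError(f"{name} has {len(frames)} frames, need seq_len={self.seq_len}")
        # seq_len consecutive frames starting at a random offset
        start = self.rng.randint(0, len(frames) - self.seq_len)
        clip = frames[start:start + self.seq_len]
        # max_obj_n: keep at most this many object ids (taken from the first frame)
        kept = sorted(clip[0])[: self.max_obj_n]
        clip = [sorted(obj for obj in f if obj in kept) for f in clip]
        return name, clip


videos = {
    "cat": [{1, 2, 3}] * 8,  # 8 frames, objects 1..3 in every frame
    "dog": [{1}] * 6,  # 6 frames, a single object
}
ds = ToyVOSTrain(videos, seq_len=5, max_obj_n=2)
name, clip = ds[1]
print(name, len(clip), clip[0])  # dog 5 [1]
```

In this sketch, raising `seq_len` above a video's frame count or passing an index that cannot be resolved to a video are the two ways a sample lookup fails, which is why checking the frame counts and the index-to-video mapping of a custom dataset is a reasonable first step.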

Thanks in advance for any answer.