PJLab-ADG / 3DTrans

An open-source codebase for exploring autonomous driving pre-training
https://bobrown.github.io/Team_3DTrans.github.io/
Apache License 2.0

Pretraining on NuScenes #21

Closed: aleksaabo1 closed this issue 9 months ago.

aleksaabo1 commented 10 months ago

I'm currently working on pretraining the AD-PT model using the NuScenes dataset, but I've hit a few roadblocks and could really use some help. Here's where I'm at:

Following the guide, I have acquired the following files:

(However, because of an error, these files had to be moved from 3DTrans/data/nuscenes/v1.0-trainval/ to 3DTrans/data/nuscenes/.)

When running the script:

sh scripts/PRETRAIN/dist_train_pointcontrast.sh 2 --cfg_file cfgs/nuscenes_models/cbgs_dyn_pp_centerpoint.yaml --batch_size 4 --epochs 30

I received this error:

  File "../pcdet/datasets/nuscenes/nuscenes_semi_dataset.py", line 114, in split_nuscenes_semi_data
    raw_split = data_splits['raw']
KeyError: 'raw'

As far as I can see, data_splits is: {'train': 'train', 'test': 'test'}, so it has no 'raw' key.
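To surface this mismatch before the splitter runs, the split config could be validated up front. This is a minimal sketch, assuming the splitter needs the keys 'train', 'test', and 'raw' (only 'raw' is confirmed by the traceback; the other two are an assumption). `missing_split_keys` is a hypothetical helper, not part of 3DTrans:

```python
from typing import Iterable, List, Mapping

# Keys assumed to be read by split_nuscenes_semi_data; only 'raw' is
# confirmed by the traceback in this issue.
REQUIRED_SPLIT_KEYS = ("train", "test", "raw")


def missing_split_keys(data_splits: Mapping[str, str],
                       required: Iterable[str] = REQUIRED_SPLIT_KEYS) -> List[str]:
    """Return the required split keys that are absent from data_splits."""
    return [key for key in required if key not in data_splits]


if __name__ == "__main__":
    data_splits = {"train": "train", "test": "test"}  # value reported in this issue
    missing = missing_split_keys(data_splits)
    if missing:
        print(f"data_splits is missing keys: {missing}")
```

Raising an error with this list, rather than silently skipping the missing split, keeps the failure close to its cause.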

After figuring this out, I modified the code (file: nuscenes_semi_dataset.py) so that it only runs if data_splits contains 'raw':


# Only load the unlabeled 'raw' infos if the split config defines them
raw_split = data_splits.get('raw')
if raw_split:
    for info_path in info_paths[raw_split]:
        if oss_path is None:
            # Local filesystem: read the info pickle relative to root_path
            info_path = root_path / info_path
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                nuscenes_unlabeled_infos.extend(copy.deepcopy(infos))
        else:
            # oss_path is set: fetch the pickle bytes through the storage client
            info_path = os.path.join(oss_path, info_path)
            pkl_bytes = client.get(info_path, update_cache=True)
            infos = pickle.load(io.BytesIO(pkl_bytes))
            nuscenes_unlabeled_infos.extend(copy.deepcopy(infos))

This removed the error. However, I then received this error:

Traceback (most recent call last):
  File "train_pointcontrast.py", line 206, in <module>
    main()
  File "train_pointcontrast.py", line 112, in main
    datasets, dataloaders, samplers = build_unsupervised_dataloader(
  File "../pcdet/datasets/__init__.py", line 301, in build_unsupervised_dataloader
    unlabeled_dataset = _semi_dataset_dict[dataset_cfg.DATASET]['UNLABELED_PAIR'](
KeyError: 'UNLABELED_PAIR'



Looking into this error, I saw that this key is not defined for NuScenesDataset. _semi_dataset_dict looks like this:


_semi_dataset_dict = {
    'ONCEDataset': {
        'PARTITION_FUNC': split_once_semi_data,
        'PRETRAIN': ONCEPretrainDataset,
        'LABELED': ONCELabeledDataset,
        'UNLABELED': ONCEUnlabeledDataset,
        'UNLABELED_PAIR': ONCEUnlabeledPairDataset,
        'TEST': ONCETestDataset
    },
    'NuScenesDataset': {
        'PARTITION_FUNC': split_nuscenes_semi_data,
        'PRETRAIN': NuScenesPretrainDataset,
        'LABELED': NuScenesLabeledDataset,
        'UNLABELED': NuScenesUnlabeledDataset,
        'TEST': NuScenesTestDataset
    },
    'KittiDataset': {
        'PARTITION_FUNC': split_kitti_semi_data,
        'PRETRAIN': KittiPretrainDataset,
        'LABELED': KittiLabeledDataset,
        'UNLABELED': KittiUnlabeledDataset,
        'TEST': KittiTestDataset
    }
}
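Because each dataset entry is a plain dict, a quick check shows which datasets lack a given role. This is a minimal sketch: the registry below copies the keys shown above but uses class-name strings as stand-ins for the real classes, and `datasets_missing_role` is a hypothetical helper:

```python
from typing import Dict, List

# Reduced copy of _semi_dataset_dict: role -> class name. Strings stand in
# for the actual dataset classes so this runs on its own.
SEMI_DATASET_ROLES: Dict[str, Dict[str, str]] = {
    "ONCEDataset": {
        "PARTITION_FUNC": "split_once_semi_data",
        "PRETRAIN": "ONCEPretrainDataset",
        "LABELED": "ONCELabeledDataset",
        "UNLABELED": "ONCEUnlabeledDataset",
        "UNLABELED_PAIR": "ONCEUnlabeledPairDataset",
        "TEST": "ONCETestDataset",
    },
    "NuScenesDataset": {
        "PARTITION_FUNC": "split_nuscenes_semi_data",
        "PRETRAIN": "NuScenesPretrainDataset",
        "LABELED": "NuScenesLabeledDataset",
        "UNLABELED": "NuScenesUnlabeledDataset",
        "TEST": "NuScenesTestDataset",
    },
    "KittiDataset": {
        "PARTITION_FUNC": "split_kitti_semi_data",
        "PRETRAIN": "KittiPretrainDataset",
        "LABELED": "KittiLabeledDataset",
        "UNLABELED": "KittiUnlabeledDataset",
        "TEST": "KittiTestDataset",
    },
}


def datasets_missing_role(registry: Dict[str, Dict[str, str]], role: str) -> List[str]:
    """Return, sorted, the dataset names whose entry has no `role` key."""
    return sorted(name for name, roles in registry.items() if role not in roles)


print(datasets_missing_role(SEMI_DATASET_ROLES, "UNLABELED_PAIR"))
# -> ['KittiDataset', 'NuScenesDataset']
```

In the version shown here only ONCEDataset registers an 'UNLABELED_PAIR' class, so the unconditional lookup in build_unsupervised_dataloader (line 301 in the traceback) raises KeyError for the other two datasets.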

I then added a condition so that this code only runs if 'UNLABELED_PAIR' is defined for the dataset (file: pcdet/datasets/__init__.py):

if 'UNLABELED_PAIR' in _semi_dataset_dict[dataset_cfg.DATASET]:
    unlabeled_dataset = _semi_dataset_dict[dataset_cfg.DATASET]['UNLABELED_PAIR'](
        dataset_cfg=dataset_cfg,
        class_names=class_names,
        infos=unlabeled_infos,
        root_path=root_path,
        logger=logger,
    )
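Guarding the assignment this way only moves the failure: when the branch is skipped, `unlabeled_dataset` is never bound, and any later line that reads it (here, the `DistributedSampler(unlabeled_dataset)` call) raises UnboundLocalError. This is a minimal sketch of that mechanism plus a fail-fast alternative that raises a descriptive error at lookup time; `build_with_guard` and `lookup_role` are hypothetical names, and the registry is a stand-in:

```python
from typing import Dict


def build_with_guard(roles: Dict[str, str]) -> str:
    # Mirrors the guarded code: the name is assigned only inside the branch.
    if "UNLABELED_PAIR" in roles:
        unlabeled_dataset = roles["UNLABELED_PAIR"]
    # If the branch was skipped, reading the name raises UnboundLocalError.
    return unlabeled_dataset


def lookup_role(registry: Dict[str, Dict[str, str]], dataset: str, role: str) -> str:
    """Fail fast with a message naming the dataset and its available roles."""
    roles = registry.get(dataset, {})
    if role not in roles:
        raise KeyError(
            f"{dataset} has no {role!r} entry; available roles: {sorted(roles)}"
        )
    return roles[role]


if __name__ == "__main__":
    registry = {"NuScenesDataset": {"UNLABELED": "NuScenesUnlabeledDataset"}}
    try:
        build_with_guard(registry["NuScenesDataset"])
    except UnboundLocalError as exc:
        print("guarded version:", type(exc).__name__)
    try:
        lookup_role(registry, "NuScenesDataset", "UNLABELED_PAIR")
    except KeyError as exc:
        print("fail-fast version:", exc)
```

Neither variant supplies the missing class: running this pipeline on nuScenes would presumably require a pair-dataset class registered under 'UNLABELED_PAIR'.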

Then this happened:

2023-11-10 13:42:20,443 nuscenes_semi_dataset.py split_nuscenes_semi_data 130  INFO  Total samples for nuscenes testing dataset: 0
2023-11-10 13:42:20,443 nuscenes_semi_dataset.py split_nuscenes_semi_data 131  INFO  Total samples for nuscenes labeled dataset: 0
2023-11-10 13:42:20,443 nuscenes_semi_dataset.py split_nuscenes_semi_data 132  INFO  Total samples for nuscenes unlabeled dataset: 0
Traceback (most recent call last):
  File "train_pointcontrast.py", line 206, in <module>
    main()
  File "train_pointcontrast.py", line 112, in main
    datasets, dataloaders, samplers = build_unsupervised_dataloader(
  File "../pcdet/datasets/__init__.py", line 312, in build_unsupervised_dataloader
    unlabeled_sampler = torch.utils.data.distributed.DistributedSampler(unlabeled_dataset)
UnboundLocalError: local variable 'unlabeled_dataset' referenced before assignment
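The log lines before the traceback report 0 samples for the testing, labeled, and unlabeled splits, so even with a dataset object the samplers would have nothing to draw from. This is a minimal sketch of a fail-fast check that could run right after splitting and before any DistributedSampler is built; `require_nonempty` is a hypothetical helper, and the split names are illustrative:

```python
from typing import Dict, Sequence


def require_nonempty(splits: Dict[str, Sequence]) -> None:
    """Raise ValueError listing every split whose info list is empty."""
    empty = [name for name, infos in splits.items() if len(infos) == 0]
    if empty:
        raise ValueError(
            f"empty info lists for split(s) {empty}; check the info .pkl paths "
            "and the data_splits entries in the dataset config"
        )


if __name__ == "__main__":
    # Mirrors the counts in the log above: every split is empty.
    try:
        require_nonempty({"test": [], "labeled": [], "unlabeled": []})
    except ValueError as exc:
        print(exc)
```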

Any ideas on how to tackle these errors?

BOBrown commented 10 months ago

@aleksaabo1

Hi, please note that scripts/PRETRAIN/dist_train_pointcontrast.sh does not run pre-training with our proposed AD-PT; it runs pre-training with PointContrast.

We will release the AD-PT pre-training code here as soon as possible.