open-mmlab / mmpretrain

OpenMMLab Pre-training Toolbox and Benchmark
https://mmpretrain.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Error when trying to replicate the tutorial for pre-training with custom dataset. #1919

Open siddhi-wiai opened 4 months ago

siddhi-wiai commented 4 months ago

Branch

main branch (mmpretrain version)

Describe the bug

I followed exactly the steps in the tutorial for pre-training MAE on a custom dataset, but I get the following error:

File "/home/XXX/code_siddhi/mmpretrain/mmpretrain/models/utils/data_preprocessor.py", line 261, in
    _input[:, [2, 1, 0], ...] for _input in batch_inputs
TypeError: string indices must be integers
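The last line of the traceback means that each `_input` reached the channel-swap step as a Python `str` rather than an image tensor: indexing a string with a tuple such as `[:, [2, 1, 0], ...]` raises exactly this `TypeError`. The plain-Python sketch below reproduces the message without mmpretrain. The function name `swap_bgr_rgb` and the string input are hypothetical, chosen only to trigger the error; newer Python versions append `, not 'tuple'` to the message.

```python
def swap_bgr_rgb(batch_inputs):
    # Same expression as line 261 of data_preprocessor.py: reorder the
    # channels along dim 1 (BGR <-> RGB) for every item in the batch.
    return [_input[:, [2, 1, 0], ...] for _input in batch_inputs]


# Hypothetical bad input: strings where image tensors were expected.
try:
    swap_bgr_rgb(['not-a-tensor'])
except TypeError as err:
    error_message = str(err)
else:
    error_message = None

print(error_message)
```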

Environment

{'sys.platform': 'linux', 'Python': '3.8.19 (default, Mar 20 2024, 19:58:24) [GCC 11.2.0]', 'CUDA available': True, 'MUSA available': False, 'numpy_random_seed': 2147483648, 'GPU 0': 'NVIDIA L4', 'CUDA_HOME': '/usr/local/cuda', 'NVCC': 'Cuda compilation tools, release 12.3, V12.3.107', 'GCC': 'gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0', 'PyTorch': '1.10.1', 'TorchVision': '0.11.2', 'OpenCV': '4.10.0', 'MMEngine': '0.10.4', 'MMCV': '2.0.1', 'MMPreTrain': '1.2.0+17a886c'}

Other information

keiohta commented 4 months ago

Hi @siddhi-wiai, I encountered the same issue and found the following solution:

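  1. Specify `train_pipeline` explicitly (copied from the base dataset config).
  2. Add `_delete_=True` to the `dataset` dict so that keys inherited from the base config, such as `split` (which `CustomDataset` does not accept), are discarded. Because `_delete_=True` also discards the inherited `pipeline`, step 1 is required.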
# >>>>>>>>>>>>>>> Override dataset settings here >>>>>>>>>>>>>>>>>>>
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        crop_ratio_range=(0.2, 1.0),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackInputs')
]
train_dataloader = dict(
    batch_size=128,
    dataset=dict(
        type='CustomDataset',
        data_root='data/custom_dataset/',
        ann_file='',  # We assume you are using the sub-folder format without ann_file
        data_prefix='',  # The `data_root` is the data_prefix directly.
        with_label=False,
        _delete_=True,  # Drop inherited keys such as `split`, which CustomDataset does not accept
        pipeline=train_pipeline  # Re-specify the pipeline, since `_delete_` also drops the inherited one
    )
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
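Why both changes are needed can be illustrated with a simplified model of MMEngine's config inheritance: a child dict is normally merged key by key into the base dict, so inherited keys like `split` survive; with `_delete_=True` the child dict replaces the base dict entirely, which also removes the inherited `pipeline`. The sketch below is a hypothetical re-implementation for illustration only (`merge` and `BASE_DATASET` are made-up names, and the base values are assumed, not copied from mmpretrain); MMEngine's real logic handles more cases.

```python
import copy

# Assumed, simplified stand-in for the dataset entry in the MAE base config.
BASE_DATASET = dict(
    type='ImageNet',
    data_root='data/imagenet/',
    split='train',
    pipeline=['<base train_pipeline>'],
)


def merge(override, base):
    """Hypothetical, simplified version of MMEngine-style config inheritance.

    Keys in `override` win; nested dicts are merged recursively unless the
    override dict carries `_delete_=True`, in which case it replaces `base`.
    """
    override = copy.deepcopy(override)
    if override.pop('_delete_', False):
        return override  # base dict is ignored entirely
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(value, merged[key])
        else:
            merged[key] = value
    return merged


OVERRIDE = dict(
    type='CustomDataset',
    data_root='data/custom_dataset/',
    ann_file='',
    data_prefix='',
    with_label=False,
)

without_delete = merge(OVERRIDE, BASE_DATASET)  # keeps `split` and base pipeline
with_delete = merge(dict(OVERRIDE, _delete_=True), BASE_DATASET)  # both gone
fixed = merge(
    dict(OVERRIDE, _delete_=True, pipeline=['<custom train_pipeline>']),
    BASE_DATASET,
)  # the shape of the fix above: no `split`, explicit pipeline

print(sorted(without_delete))
print(sorted(with_delete))
print(fixed['pipeline'])
```

In this model, `with_delete` has no `pipeline` at all, which is why the fix passes `pipeline=train_pipeline` alongside `_delete_=True`.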