ssnl / dataset-distillation

Open-source code for the paper "Dataset Distillation"
https://ssnl.github.io/dataset_distillation
MIT License
778 stars, 115 forks

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 4035 and 4003 in dimension 2 at /opt/conda/conda-bld/pytorch_1544199946412/work/aten/src/TH/generic/THTensorMoreMath.cpp:1333 #59

Closed: shraddha291996 closed this issue 1 year ago

shraddha291996 commented 1 year ago

Hello,

I am using AlexNet (AlexCifarNet) and a custom dataset of RGB images, which I am cropping to 32x32. I get the following error. Could you please help me with this issue?

(new-distil) [la00204104@wgs.wuerth.com@cn1401lx000001 dataset-distillation]$ python main.py --mode distill_basic --dataset data_custom --arch AlexCifarNet --distill_lr 0.001 --batch_size 1  --num_workers 1 --test_batch_size 2 --test_optimize_n_nets 1
INFO:root:Logging to ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)/output.log
WARNING:root:Log file already exists, will append
2023-01-07 15:45:19 [INFO ]  ======================================== 2023-01-07 15:45:19 ========================================
2023-01-07 15:45:19 [INFO ]  Base directory is ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)
2023-01-07 15:45:19 [INFO ]  Options:
2023-01-07 15:45:19 [INFO ]     arch: AlexCifarNet
2023-01-07 15:45:19 [INFO ]     attack_class: 0
2023-01-07 15:45:19 [INFO ]     base_seed: 1
2023-01-07 15:45:19 [INFO ]     batch_size: 1
2023-01-07 15:45:19 [INFO ]     checkpoint_interval: 10
2023-01-07 15:45:19 [INFO ]     dataset: data_custom
2023-01-07 15:45:19 [INFO ]     dataset_labels: !!python/tuple
2023-01-07 15:45:19 [INFO ]     - class1
2023-01-07 15:45:19 [INFO ]     - class2
2023-01-07 15:45:19 [INFO ]     - class3
2023-01-07 15:45:19 [INFO ]     - class4
2023-01-07 15:45:19 [INFO ]     dataset_normalization: !!python/tuple
2023-01-07 15:45:19 [INFO ]     - !!python/tuple
2023-01-07 15:45:19 [INFO ]         - 0.4914
2023-01-07 15:45:19 [INFO ]         - 0.4822
2023-01-07 15:45:19 [INFO ]         - 0.4465
2023-01-07 15:45:19 [INFO ]     - !!python/tuple
2023-01-07 15:45:19 [INFO ]         - 0.247
2023-01-07 15:45:19 [INFO ]         - 0.243
2023-01-07 15:45:19 [INFO ]         - 0.261
2023-01-07 15:45:19 [INFO ]     dataset_root: /home/wgs.wuerth.com/la00204104/dataset-distillation/data/dataset
2023-01-07 15:45:19 [INFO ]     decay_epochs: 40
2023-01-07 15:45:19 [INFO ]     decay_factor: 0.5
2023-01-07 15:45:19 [INFO ]     device_id: 0
2023-01-07 15:45:19 [INFO ]     distill_epochs: 3
2023-01-07 15:45:19 [INFO ]     distill_lr: 0.001
2023-01-07 15:45:19 [INFO ]     distill_steps: 10
2023-01-07 15:45:19 [INFO ]     distilled_images_per_class_per_step: 1
2023-01-07 15:45:19 [INFO ]     distributed: false
2023-01-07 15:45:19 [INFO ]     dropout: false
2023-01-07 15:45:19 [INFO ]     epochs: 400
2023-01-07 15:45:19 [INFO ]     expr_name_format: null
2023-01-07 15:45:19 [INFO ]     image_dpi: 80
2023-01-07 15:45:19 [INFO ]     init: xavier
2023-01-07 15:45:19 [INFO ]     init_param: 1.0
2023-01-07 15:45:19 [INFO ]     input_size: 32
2023-01-07 15:45:19 [INFO ]     log_file: ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)/output.log
2023-01-07 15:45:19 [INFO ]     log_interval: 100
2023-01-07 15:45:19 [INFO ]     log_level: INFO
2023-01-07 15:45:19 [INFO ]     lr: 0.01
2023-01-07 15:45:19 [INFO ]     mode: distill_basic
2023-01-07 15:45:19 [INFO ]     model_dir: ./models/
2023-01-07 15:45:19 [INFO ]     model_subdir_format: null
2023-01-07 15:45:19 [INFO ]     n_nets: 1
2023-01-07 15:45:19 [INFO ]     nc: 3
2023-01-07 15:45:19 [INFO ]     no_log: false
2023-01-07 15:45:19 [INFO ]     num_classes: 4
2023-01-07 15:45:19 [INFO ]     num_workers: 1
2023-01-07 15:45:19 [INFO ]     phase: train
2023-01-07 15:45:19 [INFO ]     results_dir: ./results/
2023-01-07 15:45:19 [INFO ]     sample_n_nets: 1
2023-01-07 15:45:19 [INFO ]     source_dataset: null
2023-01-07 15:45:19 [INFO ]     start_time: '2023-01-07 15:45:19'
2023-01-07 15:45:19 [INFO ]     target_class: 1
2023-01-07 15:45:19 [INFO ]     test_batch_size: 2
2023-01-07 15:45:19 [INFO ]     test_distill_epochs: null
2023-01-07 15:45:19 [INFO ]     test_distilled_images: loaded
2023-01-07 15:45:19 [INFO ]     test_distilled_lrs:
2023-01-07 15:45:19 [INFO ]     - loaded
2023-01-07 15:45:19 [INFO ]     test_n_nets: 1
2023-01-07 15:45:19 [INFO ]     test_n_runs: 1
2023-01-07 15:45:19 [INFO ]     test_name_format: null
2023-01-07 15:45:19 [INFO ]     test_nets_type: unknown_init
2023-01-07 15:45:19 [INFO ]     test_niter: 1
2023-01-07 15:45:19 [INFO ]     test_optimize_n_nets: 1
2023-01-07 15:45:19 [INFO ]     test_optimize_n_runs: null
2023-01-07 15:45:19 [INFO ]     train_nets_type: unknown_init
2023-01-07 15:45:19 [INFO ]     world_rank: 0
2023-01-07 15:45:19 [INFO ]     world_size: 1
2023-01-07 15:45:19 [INFO ]
2023-01-07 15:45:19 [WARNING]  ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)/opt.yaml already exists, moved to ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)/old_opts/opt_2023_01_07__15_20_59.yaml
2023-01-07 15:45:29 [INFO ]  train dataset size:        178
2023-01-07 15:45:29 [INFO ]  test dataset size:         178
2023-01-07 15:45:29 [INFO ]  datasets built!
2023-01-07 15:45:29 [INFO ]  mode: distill_basic, phase: train
2023-01-07 15:45:29 [INFO ]  Build 1 AlexCifarNet network(s) with [xavier(1.0)] init
2023-01-07 15:45:33 [INFO ]  Build 1 AlexCifarNet network(s) with [xavier(1.0)] init
2023-01-07 15:45:33 [INFO ]  Train 10 steps iterated for 3 epochs
2023-01-07 15:45:33 [INFO ]  Results saved to ./results/distill_basic/data_custom/arch(AlexCifarNet,xavier,1.0)_distillLR0.001_E(400,40,0.5)_lr0.01_B1x10x3_train(unknown_init)/checkpoints/epoch0000/results.pth
2023-01-07 15:45:33 [INFO ]
2023-01-07 15:45:33 [INFO ]  Begin of epoch 0 :
2023-01-07 15:45:35 [ERROR]  Fatal error:
2023-01-07 15:45:35 [ERROR]  Traceback (most recent call last):
2023-01-07 15:45:35 [ERROR]    File "main.py", line 402, in <module>
2023-01-07 15:45:35 [ERROR]      main(options.get_state())
2023-01-07 15:45:35 [ERROR]    File "main.py", line 131, in main
2023-01-07 15:45:35 [ERROR]      steps = train_distilled_image.distill(state, state.models)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/train_distilled_image.py", line 290, in distill
2023-01-07 15:45:35 [ERROR]      return Trainer(state, models).train()
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/train_distilled_image.py", line 221, in train
2023-01-07 15:45:35 [ERROR]      evaluate_steps(state, steps, 'Begin of epoch {}'.format(epoch))
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 288, in evaluate_steps
2023-01-07 15:45:35 [ERROR]      res = _evaluate_steps(test_nets_desc, reset=(state.test_nets_type == 'unknown_init'))
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 276, in _evaluate_steps
2023-01-07 15:45:35 [ERROR]      params = train_steps_inplace(state, models, steps, params, callback=test_callback)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 65, in train_steps_inplace
2023-01-07 15:45:35 [ERROR]      callback(i, params)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 268, in test_callback
2023-01-07 15:45:35 [ERROR]      test_loader_iter=test_loader_iter)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 120, in evaluate_models
2023-01-07 15:45:35 [ERROR]      for i, (data, target) in enumerate(test_loader_iter):
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/dataset-distillation/basics.py", line 219, in infinite_iterator
2023-01-07 15:45:35 [ERROR]      yield from iter(iterable)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 637, in __next__
2023-01-07 15:45:35 [ERROR]      return self._process_next_batch(batch)
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 658, in _process_next_batch
2023-01-07 15:45:35 [ERROR]      raise batch.exc_type(batch.exc_msg)
2023-01-07 15:45:35 [ERROR]  RuntimeError: Traceback (most recent call last):
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
2023-01-07 15:45:35 [ERROR]      samples = collate_fn([dataset[i] for i in batch_indices])
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 232, in default_collate
2023-01-07 15:45:35 [ERROR]      return [default_collate(samples) for samples in transposed]
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 232, in <listcomp>
2023-01-07 15:45:35 [ERROR]      return [default_collate(samples) for samples in transposed]
2023-01-07 15:45:35 [ERROR]    File "/home/wgs.wuerth.com/la00204104/anaconda3/envs/new-distil/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 209, in default_collate
2023-01-07 15:45:35 [ERROR]      return torch.stack(batch, 0, out=out)
2023-01-07 15:45:35 [ERROR]  RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 4035 and 4003 in dimension 2 at /opt/conda/conda-bld/pytorch_1544199946412/work/aten/src/TH/generic/THTensorMoreMath.cpp:1333
Begin of epoch 0 (1 unknown_init nets):   0%|          | 0/2 [00:01<?, ?it/s]

Thanks in advance

ssnl commented 1 year ago

The error message says that your samples are not all the same size. I am sorry, but I can't provide help with custom datasets.

shraddha291996 commented 1 year ago

Thank you for your reply. Sorry, I am new to this. Can you please tell me where in the code I can find the tensor sizes of the samples?
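One way to find the sample sizes without digging through the training code is to index the dataset object directly, which is what the DataLoader does before collating. The sketch below is a generic helper, not part of this repository: `group_indices_by_shape` is a hypothetical name, and it assumes `dataset[i]` returns an `(image, label)` pair whose image has a `.shape` attribute (true for tensors produced by `ToTensor()`, not for raw PIL images, which expose `.size` instead).

```python
from collections import defaultdict


def group_indices_by_shape(dataset, limit=None):
    """Return {sample_shape: [indices]} for a map-style dataset.

    Assumes dataset[i] yields an (input, target) pair and that the input
    exposes `.shape` (e.g. a torch.Tensor or a numpy array).
    """
    n = len(dataset) if limit is None else min(len(dataset), limit)
    groups = defaultdict(list)
    for i in range(n):
        sample, _target = dataset[i]          # same access the DataLoader uses
        groups[tuple(sample.shape)].append(i)  # torch.Size -> plain tuple key
    return dict(groups)
```

Running it as `group_indices_by_shape(test_dataset)`, where `test_dataset` stands for whatever Dataset the custom code hands to the test DataLoader, should return a single key such as `(3, 32, 32)`. More than one key means any batch that mixes those shapes will fail in `default_collate`, as in the traceback above; `limit` keeps the scan short when the images are large.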