yoshitomo-matsubara / torchdistill

A coding-free framework built on PyTorch for reproducible deep learning studies. 🏆 25 knowledge distillation methods presented at CVPR, ICLR, ECCV, NeurIPS, ICCV, etc. are implemented so far. 🎁 Trained models, training logs, and configurations are available to ensure reproducibility and benchmarking.
https://yoshitomo-matsubara.net/torchdistill/
MIT License
1.37k stars 132 forks

[BUG] fp16 causes AssertionError: No inf checks were recorded for this optimizer #386

Closed jsrdcht closed 1 year ago

jsrdcht commented 1 year ago

Describe the bug

I modified examples/legacy/image_classification.py to adapt it to Hugging Face accelerate and ran into the following error:

Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
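To illustrate why this assertion exists: GradScaler.step() only succeeds if inf/NaN checks were recorded for the optimizer's gradients, which happens when the backward pass went through scaler.scale(loss) (or through Accelerate's accelerator.backward, which scales the loss internally). Below is a toy re-creation of that bookkeeping, purely to show the failure mode; ToyScaler and ToyLoss are made up for this sketch and are not the real torch API:

```python
# Toy model of torch.cuda.amp.GradScaler's bookkeeping (NOT the real API).
# It shows why scaler.step() insists that an inf check was recorded first.

class ToyLoss:
    def __init__(self):
        self.scaler = None
    def backward(self):
        # In real AMP, backward on a *scaled* loss is what later lets the
        # scaler record per-device inf checks for the optimizer.
        if self.scaler is not None:
            self.scaler._backward_was_scaled = True

class ToyScaler:
    """Mimics only GradScaler's scale()/step() bookkeeping."""
    def __init__(self):
        self._backward_was_scaled = False
        self.found_inf_per_device = {}
    def scale(self, loss):
        loss.scaler = self
        return loss
    def step(self):
        # The real scaler populates found_inf_per_device in unscale_();
        # that only happens if backward went through a scaled loss.
        if self._backward_was_scaled:
            self.found_inf_per_device["cuda:0"] = 0.0
        assert len(self.found_inf_per_device) > 0, \
            "No inf checks were recorded for this optimizer."

# Correct order: scale -> backward -> step. No assertion.
ok = ToyScaler()
ok.scale(ToyLoss()).backward()
ok.step()

# Broken order: plain backward() bypasses the scaler, so step() fails
# with the same message as in the traceback above.
broken = ToyScaler()
ToyLoss().backward()          # backward never went through scale()
try:
    broken.step()
except AssertionError as e:
    print(e)                  # -> No inf checks were recorded for this optimizer.
```

The real GradScaler records checks per device inside unscale_(); the toy collapses that into one flag, but the triggering condition is the same: step() without a scaled backward.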

To Reproduce

Provide:

  1. Exact command to run your code: accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
  2. Whether or not you made any changes in Python code (if so, how?): I enabled the fp16 multi-GPU option in the Accelerate configuration file. My main experiment configuration file is for the AT algorithm. I modified the image_classification file, mainly following the author's changes to text_classification.py; I made no personalized changes beyond those minimal modifications, which ultimately led to this error.
  3. YAML config file
    
datasets:
  ilsvrc2012:
    name: &dataset_name 'ilsvrc2012'
    type: 'ImageFolder'
    root: &root_dir !join ['/workspace/sync/imagenet-1k']
    splits:
      train:
        dataset_id: &imagenet_train !join [*dataset_name, '/train']
        params:
          root: !join [*root_dir, '/train']
          transform_params:
            - type: 'RandomResizedCrop'
              params:
                size: &input_size [224, 224]
            - type: 'RandomHorizontalFlip'
              params:
                p: 0.5
            - &totensor
              type: 'ToTensor'
              params:
            - &normalize
              type: 'Normalize'
              params:
                mean: [0.485, 0.456, 0.406]
                std: [0.229, 0.224, 0.225]
      val:
        dataset_id: &imagenet_val !join [*dataset_name, '/val']
        params:
          root: !join [*root_dir, '/val']
          transform_params:
            - type: 'Resize'
              params:
                size: 256
            - type: 'CenterCrop'
              params:
                size: *input_size
            - *totensor
            - *normalize

models:
  teacher_model:
    name: &teacher_model_name 'maskedvit_base_patch16_224'
    params:
      num_classes: 1000
      pretrained: True
      mask_ratio: 0.0
    experiment: &teacher_experiment !join [*dataset_name, '-', *teacher_model_name]
    ckpt: !join ['./resource/ckpt/ilsvrc2012/teacher/', *teacher_experiment, '.pt']
  student_model:
    name: &student_model_name 'maskedvit_base_patch16_224'
    params:
      num_classes: 1000
      pretrained: False
      mask_ratio: 0.5
    experiment: &student_experiment !join [*dataset_name, '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./imagenet/mask_distillation/', *student_experiment, '.pt']

train:
  log_freq: 1000
  num_epochs: 100
  train_data_loader:
    dataset_id: *imagenet_train
    random_sample: True
    batch_size: 64
    num_workers: 16
    cache_output:
  val_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 128
    num_workers: 16
  teacher:
    sequential: []
    forward_hook:
      input: []
      output: ['mask_filter']
    wrapper: 'DataParallel'
    requires_grad: False
  student:
    adaptations:
    sequential: []
    frozen_modules: []
    forward_hook:
      input: []
      output: ['mask_filter']
    wrapper: 'DistributedDataParallel'
    requires_grad: True
  optimizer:
    type: 'SGD'
    grad_accum_step: 16
    max_grad_norm: 5.0
    module_wise_params:

test:
  test_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 1
    num_workers: 16
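For reference only, since the reporter's Accelerate configuration file itself is not shown: the "fp16 multi-gpu option" is usually set in the file written by `accelerate config` (typically ~/.cache/huggingface/accelerate/default_config.yaml). A minimal illustrative sketch, not the reporter's actual file:

```yaml
# Illustrative Accelerate config sketch (not the reporter's actual file).
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: fp16     # enables a GradScaler inside Accelerate
num_machines: 1
num_processes: 4          # the log shows 4 ranks
use_cpu: false
```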

6. Log file

(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# accelerate launch examples/legacy/image_classification_accelerate.py --config /workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml
2023/08/15 02:49:09 INFO main Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 0
2023/08/15 02:49:09 INFO main Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 1
2023/08/15 02:49:09 INFO main Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO main Namespace(adjust_lr=False, config='/workspace/sync/torchdistill/configs/legacy/official/ilsvrc2012/yoshitomo-matsubara/rrpr2020/at-vit-base_from_vit-base.yaml', device='cuda', dist_url='env://', log=None, log_config=False, seed=None, start_epoch=0, student_only=False, test_only=False, world_size=1)
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 2
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Added key: store_based_barrier_key:1 to store for rank: 3
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO torch.distributed.distributed_c10d Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
2023/08/15 02:49:09 INFO main Distributed environment: MULTI_GPU
Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

2023/08/15 02:49:09 INFO torchdistill.datasets.util Loading train data
2023/08/15 02:49:12 INFO torchdistill.datasets.util dataset_id ilsvrc2012/train: 2.874385356903076 sec
2023/08/15 02:49:12 INFO torchdistill.datasets.util Loading val data
2023/08/15 02:49:12 INFO torchdistill.datasets.util dataset_id ilsvrc2012/val: 0.12787175178527832 sec
2023/08/15 02:49:15 INFO timm.models._builder Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2023/08/15 02:49:16 INFO timm.models._hub [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2023/08/15 02:49:16 INFO torchdistill.common.main_util ckpt file is not found at ./resource/ckpt/ilsvrc2012/teacher/ilsvrc2012-maskedvit_base_patch16_224.pt
2023/08/15 02:49:18 INFO torchdistill.common.main_util ckpt file is not found at ./imagenet/mask_distillation/ilsvrc2012-maskedvit_base_patch16_224_from_maskedvit_base_patch16_224.pt
2023/08/15 02:49:18 INFO main Start training
2023/08/15 02:49:18 INFO torchdistill.models.util [teacher model]
2023/08/15 02:49:18 INFO torchdistill.models.util Using the original teacher model
2023/08/15 02:49:18 INFO torchdistill.models.util [student model]
2023/08/15 02:49:18 INFO torchdistill.models.util Using the original student model
2023/08/15 02:49:18 INFO torchdistill.core.distillation Loss = 1.0 OrgLoss + 1.0 GenerativeKDLoss(
  (cross_entropy_loss): CrossEntropyLoss()
  (SmoothL1Loss): SmoothL1Loss()
)
2023/08/15 02:49:18 INFO torchdistill.core.distillation Freezing the whole teacher model
2023/08/15 02:49:18 INFO torchdistill.common.module_util None of None could not be reached in DataParallel
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The use_fp16 property is deprecated and will be removed in version 1.0 of Accelerate use AcceleratorState.mixed_precision == 'fp16' instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The use_fp16 property is deprecated and will be removed in version 1.0 of Accelerate use AcceleratorState.mixed_precision == 'fp16' instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The use_fp16 property is deprecated and will be removed in version 1.0 of Accelerate use AcceleratorState.mixed_precision == 'fp16' instead.
  warnings.warn(
/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/state.py:802: FutureWarning: The use_fp16 property is deprecated and will be removed in version 1.0 of Accelerate use AcceleratorState.mixed_precision == 'fp16' instead.
  warnings.warn(
2023/08/15 02:49:24 INFO torchdistill.misc.log Epoch: [0]  [   0/5005]  eta: 8:39:24  lr: 0.001  img/s: 21.99282017795937  loss: 0.4513 (0.4513)  time: 6.2267  data: 3.3162  max mem: 8400
2023/08/15 02:49:24 INFO torch.nn.parallel.distributed Reducer buckets have been rebuilt in this iteration.
2023/08/15 02:49:24 INFO torch.nn.parallel.distributed Reducer buckets have been rebuilt in this iteration.
The same traceback was raised on every rank; deinterleaved, it reads:

Traceback (most recent call last):
  File "examples/legacy/image_classification_accelerate.py", line 217, in <module>
    main(argparser.parse_args())
  File "examples/legacy/image_classification_accelerate.py", line 198, in main
    train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
  File "examples/legacy/image_classification_accelerate.py", line 129, in train
    train_one_epoch(training_box, device, epoch, log_freq)
  File "examples/legacy/image_classification_accelerate.py", line 71, in train_one_epoch
    training_box.update_params(loss)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torchdistill/core/distillation.py", line 316, in update_params
    self.optimizer.step()
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/optimizer.py", line 133, in step
    self.scaler.step(self.optimizer, closure)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 339, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 3701268) of binary: /root/miniconda3/envs/pytorch_1/bin/python
Traceback (most recent call last):
  File "/root/miniconda3/envs/pytorch_1/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 970, in launch_command
    multi_gpu_launcher(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/accelerate/commands/launch.py", line 646, in multi_gpu_launcher
    distrib_run.run(args)
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

examples/legacy/image_classification_accelerate.py FAILED

Failures:
[1]:
  time       : 2023-08-15_02:49:37
  host       : baa8ef5448b2
  rank       : 1 (local_rank: 1)
  exitcode   : 1 (pid: 3701269)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time       : 2023-08-15_02:49:37
  host       : baa8ef5448b2
  rank       : 2 (local_rank: 2)
  exitcode   : 1 (pid: 3701270)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time       : 2023-08-15_02:49:37
  host       : baa8ef5448b2
  rank       : 3 (local_rank: 3)
  exitcode   : 1 (pid: 3701271)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
  time       : 2023-08-15_02:49:37
  host       : baa8ef5448b2
  rank       : 0 (local_rank: 0)
  exitcode   : 1 (pid: 3701268)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html


**Environment (please complete the following information):**
 - OS: Ubuntu 22.04 LTS
 - Python ver. 3.8
 - torchdistill ver. v0.3.3

(pytorch_1) root@baa8ef5448b2:/workspace/sync/torchdistill# conda list

# packages in environment at /root/miniconda3/envs/pytorch_1:
#
# Name                    Version          Build    Channel

_libgcc_mutex 0.1 main defaults _openmp_mutex 5.1 1_gnu defaults accelerate 0.21.0 pypi_0 pypi blas 1.0 mkl defaults brotlipy 0.7.0 py38h27cfd23_1003 defaults bzip2 1.0.8 h7b6447c_0 defaults ca-certificates 2023.05.30 h06a4308_0 defaults certifi 2023.7.22 py38h06a4308_0 defaults cffi 1.15.1 py38h5eee18b_3 defaults charset-normalizer 2.0.4 pyhd3eb1b0_0 defaults contourpy 1.1.0 pypi_0 pypi cryptography 41.0.2 py38h22a60cf_0 defaults cuda-cudart 11.7.99 0 nvidia cuda-cupti 11.7.101 0 nvidia cuda-libraries 11.7.1 0 nvidia cuda-nvrtc 11.7.99 0 nvidia cuda-nvtx 11.7.91 0 nvidia cuda-runtime 11.7.1 0 nvidia cycler 0.11.0 pypi_0 pypi cython 3.0.0 pypi_0 pypi ffmpeg 4.3 hf484d3e_0 pytorch filelock 3.12.2 pypi_0 pypi fonttools 4.42.0 pypi_0 pypi freetype 2.12.1 h4a9f257_0 defaults fsspec 2023.6.0 pypi_0 pypi future 0.18.3 py38h06a4308_0 defaults giflib 5.2.1 h5eee18b_3 defaults gmp 6.2.1 h295c915_3 defaults gnutls 3.6.15 he1e5248_0 defaults huggingface-hub 0.16.4 pypi_0 pypi idna 3.4 py38h06a4308_0 defaults importlib-resources 6.0.1 pypi_0 pypi intel-openmp 2023.1.0 hdb19cb5_46305 defaults jpeg 9e h5eee18b_1 defaults kiwisolver 1.4.4 pypi_0 pypi lame 3.100 h7b6447c_0 defaults lcms2 2.12 h3be6417_0 defaults ld_impl_linux-64 2.38 h1181459_1 defaults lerc 3.0 h295c915_0 defaults libcublas 11.10.3.66 0 nvidia libcufft 10.7.2.124 h4fbf590_0 nvidia libcufile 1.7.1.12 0 nvidia libcurand 10.3.3.129 0 nvidia libcusolver 11.4.0.1 0 nvidia libcusparse 11.7.4.91 0 nvidia libdeflate 1.17 h5eee18b_0 defaults libffi 3.4.4 h6a678d5_0 defaults libgcc-ng 11.2.0 h1234567_1 defaults libgfortran-ng 11.2.0 h00389a5_1 defaults libgfortran5 11.2.0 h1234567_1 defaults libgomp 11.2.0 h1234567_1 defaults libiconv 1.16 h7f8727e_2 defaults libidn2 2.3.4 h5eee18b_0 defaults libnpp 11.7.4.75 0 nvidia libnvjpeg 11.8.0.2 0 nvidia libopenblas 0.3.21 h043d6bf_0 defaults libpng 1.6.39 h5eee18b_0 defaults libprotobuf 3.20.3 he621ea3_0 defaults libstdcxx-ng 11.2.0 h1234567_1 defaults libtasn1 4.19.0 h5eee18b_0 
defaults libtiff 4.5.0 h6a678d5_2 defaults libunistring 0.9.10 h27cfd23_0 defaults libwebp 1.2.4 h11a3e52_1 defaults libwebp-base 1.2.4 h5eee18b_1 defaults lz4-c 1.9.4 h6a678d5_0 defaults matplotlib 3.7.2 pypi_0 pypi mkl 2023.1.0 h213fc3f_46343 defaults mkl-service 2.4.0 py38h5eee18b_1 defaults mkl_fft 1.3.6 py38h417a72b_1 defaults mkl_random 1.2.2 py38h417a72b_1 defaults ncurses 6.4 h6a678d5_0 defaults nettle 3.7.3 hbbd107a_1 defaults ninja 1.10.2 h06a4308_5 defaults ninja-base 1.10.2 hd09550d_5 defaults numpy 1.24.3 py38hf6e8229_1 defaults numpy-base 1.24.3 py38h060ed82_1 defaults openh264 2.1.1 h4ff587b_0 defaults openssl 3.0.10 h7f8727e_0 defaults packaging 23.1 pypi_0 pypi pillow 9.4.0 py38h6a678d5_0 defaults pip 23.2.1 py38h06a4308_0 defaults psutil 5.9.5 pypi_0 pypi pycocotools 2.0.6 pypi_0 pypi pycparser 2.21 pyhd3eb1b0_0 defaults pyopenssl 23.2.0 py38h06a4308_0 defaults pyparsing 3.0.9 pypi_0 pypi pysocks 1.7.1 py38h06a4308_0 defaults python 3.8.17 h955ad1f_0 defaults python-dateutil 2.8.2 pypi_0 pypi pytorch 1.13.0 py3.8_cuda11.7_cudnn8.5.0_0 pytorch pytorch-cuda 11.7 h778d358_5 pytorch pytorch-mutex 1.0 cuda pytorch pyyaml 6.0 py38h5eee18b_1 defaults readline 8.2 h5eee18b_0 defaults requests 2.31.0 py38h06a4308_0 defaults safetensors 0.3.2 pypi_0 pypi scipy 1.10.1 pypi_0 pypi setuptools 68.0.0 py38h06a4308_0 defaults six 1.16.0 pypi_0 pypi sqlite 3.41.2 h5eee18b_0 defaults tbb 2021.8.0 hdb19cb5_0 defaults timm 0.9.5 pypi_0 pypi tk 8.6.12 h1ccaba5_0 defaults torchaudio 0.13.0 py38_cu117 pytorch torchdistill 0.3.3 pypi_0 pypi torchvision 0.14.0 py38_cu117 pytorch tqdm 4.66.1 pypi_0 pypi typing-extensions 4.7.1 py38h06a4308_0 defaults typing_extensions 4.7.1 py38h06a4308_0 defaults urllib3 1.26.16 py38h06a4308_0 defaults wheel 0.38.4 py38h06a4308_0 defaults xz 5.4.2 h5eee18b_0 defaults yaml 0.2.5 h7b6447c_0 defaults zipp 3.16.2 pypi_0 pypi zlib 1.2.13 h5eee18b_0 defaults zstd 1.5.5 hc292b87_0 defaults



yoshitomo-matsubara commented 1 year ago

Hi @jsrdcht

Since you made changes in code and did not share the actual code, I cannot confirm that this is a bug in torchdistill.

If you're trying to introduce new components and are still in a trial-and-error phase, please use Discussions instead and provide your modified code. You also still have an unanswered discussion open on this topic.

From the error message, I guess that you didn't pass accelerator when instantiating distillation_box or training_box, so the following lines inside distillation_box or training_box were never executed: https://github.com/yoshitomo-matsubara/torchdistill/blob/v0.3.3/torchdistill/core/distillation.py#L306-L307
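For readers following along, the gist of the lines linked above is a dispatch on whether an accelerator was attached: under fp16 the backward pass must go through Accelerate (which scales the loss) for its GradScaler to record inf checks before optimizer.step(). A simplified, library-agnostic sketch of that dispatch; the function name is hypothetical and this is not torchdistill's actual code:

```python
def update_params_sketch(loss, accelerator=None):
    """Hypothetical sketch of the dispatch around torchdistill v0.3.3
    distillation.py L306-L307: when an Accelerator is attached, backward
    must go through it so its GradScaler sees a scaled loss; otherwise a
    plain backward runs and scaler.step() would find no inf checks."""
    if accelerator is not None:
        accelerator.backward(loss)  # scaled backward -> inf checks recorded
        return 'accelerator'
    loss.backward()                 # unscaled backward
    return 'plain'
```

If update_params took the plain path while Accelerate still wrapped the optimizer, the exact assertion in the traceback would fire.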

jsrdcht commented 1 year ago

Here's my code. It's not the reason you guessed; I have debugged it and confirmed that that line of code is executed.

import argparse
import datetime
import logging
import os
import time

import torch
import vit
import custom_loss

from accelerate import Accelerator, DistributedType

from torch import distributed as dist
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from torchdistill.common import file_util, yaml_util, module_util
from torchdistill.common.constant import def_logger
from torchdistill.common.main_util import is_main_process, init_distributed_mode, load_ckpt, save_ckpt, set_seed, setup_for_distributed
from torchdistill.core.distillation import get_distillation_box
from torchdistill.core.training import get_training_box
from torchdistill.datasets import util
from torchdistill.eval.classification import compute_accuracy
from torchdistill.misc.log import setup_log_file, SmoothedValue, MetricLogger
from torchdistill.models.official import get_image_classification_model
from torchdistill.models.registry import get_model

logger = def_logger.getChild(__name__)

def get_argparser():
    parser = argparse.ArgumentParser(description='Knowledge distillation for image classification models')
    parser.add_argument('--config', required=True, help='yaml file path')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('--log', help='log file path')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--seed', type=int, help='seed in random number generator')
    parser.add_argument('-test_only', action='store_true', help='only test the models')
    parser.add_argument('-student_only', action='store_true', help='test the student model only')
    parser.add_argument('-log_config', action='store_true', help='log config')
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('-adjust_lr', action='store_true',
                        help='multiply learning rate by number of distributed processes (world_size)')
    return parser

def load_model(model_config, device, distributed):
    model = get_image_classification_model(model_config, distributed)
    if model is None:
        repo_or_dir = model_config.get('repo_or_dir', None)
        model = get_model(model_config['name'], repo_or_dir, **model_config['params'])

    ckpt_file_path = model_config['ckpt']
    load_ckpt(ckpt_file_path, model=model, strict=True)
    return model.to(device)

def train_one_epoch(training_box, device, epoch, log_freq):
    metric_logger = MetricLogger(delimiter='  ')
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    for sample_batch, targets, supp_dict in \
            metric_logger.log_every(training_box.train_data_loader, log_freq, header):
        start_time = time.time()
        sample_batch, targets = sample_batch.to(device), targets.to(device)
        loss = training_box(sample_batch, targets, supp_dict)
        training_box.update_params(loss)
        batch_size = sample_batch.shape[0]
        metric_logger.update(loss=loss.item(), lr=training_box.optimizer.param_groups[0]['lr'])
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
        if (torch.isnan(loss) or torch.isinf(loss)) and is_main_process():
            raise ValueError('The training loop was broken due to loss = {}'.format(loss))

@torch.inference_mode()
def evaluate(model, data_loader, device, device_ids, distributed, log_freq=1000, title=None, header='Test:'):
    model.to(device)
    if distributed:
        model = DistributedDataParallel(model, device_ids=device_ids)
    elif device.type.startswith('cuda'):
        model = DataParallel(model, device_ids=device_ids)

    if title is not None:
        logger.info(title)

    model.eval()
    metric_logger = MetricLogger(delimiter='  ')
    for image, target in metric_logger.log_every(data_loader, log_freq, header):
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(image)
        acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
        # FIXME need to take into account that the datasets
        # could have been padded in distributed setup
        batch_size = image.shape[0]
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    top1_accuracy = metric_logger.acc1.global_avg
    top5_accuracy = metric_logger.acc5.global_avg
    logger.info(' * Acc@1 {:.4f}\tAcc@5 {:.4f}\n'.format(top1_accuracy, top5_accuracy))
    return metric_logger.acc1.global_avg

def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor, accelerator) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor, accelerator=accelerator)
    best_val_top1_accuracy = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, training_box.val_data_loader, device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Best top-1 accuracy: {:.4f} -> {:.4f}'.format(best_val_top1_accuracy, val_top1_accuracy))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_top1_accuracy = val_top1_accuracy
            if distributed is False and accelerator is not None:
                student_model_without_ddp = accelerator.unwrap_model(student_model)
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()

def main(args):
    log_file_path = args.log
    if is_main_process() and log_file_path is not None:
        setup_log_file(os.path.expanduser(log_file_path))

    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)

    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))

    # distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator(mixed_precision='fp16')
    distributed = accelerator.state.distributed_type == DistributedType.MULTI_GPU
    device_ids = [accelerator.device.index]
    if distributed:
        setup_for_distributed(is_main_process())

    logger.info(accelerator.state)
    device = accelerator.device

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)

    # device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])
    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model =\
        load_model(teacher_model_config, device, distributed) if teacher_model_config is not None else None
    student_model_config =\
        models_config['student_model'] if 'student_model' in models_config else models_config['model']
    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device, distributed)
    if accelerator is not None:
        student_model, teacher_model = accelerator.prepare(student_model, teacher_model)
        for name, dataset in dataset_dict.items():
            dataset_dict[name] = accelerator.prepare(dataset)
    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args, accelerator)
        student_model_without_ddp =\
            student_model.module if module_util.check_if_wrapped(student_model) else student_model
        load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)

    test_config = config['test']
    test_data_loader_config = test_config['test_data_loader']
    test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
                                              test_data_loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    if not args.student_only and teacher_model is not None:
        evaluate(teacher_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
                 title='[Teacher: {}]'.format(teacher_model_config['name']))
    evaluate(student_model, test_data_loader, device, device_ids, distributed, log_freq=log_freq,
             title='[Student: {}]'.format(student_model_config['name']))

if __name__ == '__main__':
    argparser = get_argparser()
    main(argparser.parse_args())
yoshitomo-matsubara commented 1 year ago
    module_wise_params:
      - params: ['mask_token', 'cls_token', 'pos_embed']
        is_teacher: None
        module: None
        weight_decay: 0.0

Your `module_wise_params` entry looks broken; use `is_teacher: False` instead of `None` (or omit `is_teacher` entirely, as it defaults to `False`).

See https://github.com/yoshitomo-matsubara/torchdistill/blob/main/configs/legacy/sample/pascal_voc2012/ce/deeplabv3_resnet101.yaml#L102-L109 for the format
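Judging from that sample config and from how `module_wise_params` is parsed, a corrected entry would look something like the sketch below. The module path and hyperparameter values here are placeholders, not taken from the original config — point `module` at the actual submodule (or `nn.Parameter`) you want to configure:

```yaml
module_wise_params:
  - module: 'backbone.patch_embed'  # placeholder: dotted path to a module or nn.Parameter
    is_teacher: False               # optional; defaults to False
    params:
      lr: 0.001                     # optional per-group optimizer overrides
      weight_decay: 0.0
```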

Again, please use Discussions for this kind of question, since it doesn't look like a bug in torchdistill

jsrdcht commented 1 year ago

The issues you pointed out were indeed what caused the strange errors. I also made some changes to the source code to adapt it to my configuration.

Here are some problems that exist in the source code:

def pre_process(self, epoch=None, **kwargs):
        clear_io_dict(self.teacher_io_dict)
        clear_io_dict(self.student_io_dict)
        self.teacher_model.eval()
        self.student_model.train()
        if self.distributed and self.accelerator is None: # batch_sampler.sampler is only valid for ddp without accelerator
            self.train_data_loader.batch_sampler.sampler.set_epoch(epoch)
# Set up accelerator if necessary
        if self.accelerator is not None:
            if self.teacher_updatable:
                self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)
            else:
                # self.teacher_model = self.teacher_model.to(self.accelerator.device)
                # if self.accelerator.state.use_fp16:
                #     self.teacher_model = self.teacher_model.half()

                # since fp16 is handled by accelerate, we have to wrap the teacher model; otherwise the input can't be cast to fp16
                self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)
            module_wise_params_configs = optim_config.get('module_wise_params', list())
            if len(module_wise_params_configs) > 0:
                trainable_module_list = list()
                for module_wise_params_config in module_wise_params_configs:
                    module_wise_params_dict = dict()
                    if isinstance(module_wise_params_config.get('params', None), dict):
                        module_wise_params_dict.update(module_wise_params_config['params'])

                    if 'lr' in module_wise_params_dict:
                        module_wise_params_dict['lr'] *= self.lr_factor

                    target_model = \
                        self.teacher_model if module_wise_params_config.get('is_teacher', False) else self.student_model
                    module = get_module(target_model, module_wise_params_config['module'])
                    # support for nn.Parameter()
                    module_wise_params_dict['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
                    trainable_module_list.append(module_wise_params_dict)
            else:
                trainable_module_list = nn.ModuleList([self.student_model])
                if self.teacher_updatable:
                    logger.info('Note that you are training some/all of the modules in the teacher model')
                    trainable_module_list.append(self.teacher_model)
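For what it's worth, as far as I can tell the assertion itself comes from `torch.cuda.amp.GradScaler.step`: inf checks are recorded per device only for parameters that actually hold gradients, so an optimizer whose parameter groups were built from a broken `module_wise_params` entry (and therefore never receive gradients) trips it. A toy reconstruction of that control flow, where `ToyScaler` and `ToyOptimizer` are illustrative stand-ins rather than the real torch classes:

```python
class ToyOptimizer:
    """Illustrative stand-in for a torch optimizer: just holds param groups."""
    def __init__(self, param_groups):
        self.param_groups = param_groups


class ToyScaler:
    """Toy reconstruction of the relevant GradScaler.step logic."""
    def __init__(self):
        self.found_inf_per_device = {}

    def unscale_(self, optimizer):
        # Inf checks are recorded only for parameters that actually hold
        # gradients; parameters with grad=None are skipped entirely.
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.get("grad") is not None:
                    self.found_inf_per_device[p["device"]] = 0.0

    def step(self, optimizer):
        self.unscale_(optimizer)
        assert len(self.found_inf_per_device) > 0, \
            "No inf checks were recorded for this optimizer."


# An optimizer whose parameters never received gradients:
opt = ToyOptimizer([{"params": [{"grad": None, "device": "cuda:0"}]}])
try:
    ToyScaler().step(opt)
    msg = None
except AssertionError as e:
    msg = str(e)
print(msg)  # No inf checks were recorded for this optimizer.
```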