IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch
BSD 3-Clause "New" or "Revised" License

Quantize bayesian cifar runtime error #27

Closed: szhaoesat closed this issue 9 months ago

szhaoesat commented 9 months ago

Hi, I am trying to run the quantization script with the following command:

sh scripts/quantize_bayesian_cifar.sh

The error log is:

['resnet110', 'resnet20', 'resnet32', 'resnet44', 'resnet56']
Files already downloaded and verified
Files already downloaded and verified
Preparing model for quantization....
Calibrating...
Traceback (most recent call last):
  File "/esat/thalassa1/users/szhao/wrk/deep-learning/BNN/bayesian-torch/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py", line 618, in <module>
    main()
  File "/esat/thalassa1/users/szhao/wrk/deep-learning/BNN/bayesian-torch/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py", line 338, in main
    model_int8 = quantize(model, calib_loader, args)
  File "/esat/thalassa1/users/szhao/wrk/deep-learning/BNN/bayesian-torch/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py", line 577, in quantize
    _ = prepared_model(data)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 157, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

ranganathkrishnan commented 9 months ago

Hi @szhaoesat, you seem to be running model quantization in a CUDA-based PyTorch environment. Can you try quantization in a CPU-only PyTorch environment? For example: "conda install pytorch torchvision torchaudio cpuonly -c pytorch"

szhaoesat commented 9 months ago

Hi @ranganathkrishnan, thanks for your help! I tried the CPU version. However, it now fails with a different error:

['resnet110', 'resnet20', 'resnet32', 'resnet44', 'resnet56']
Files already downloaded and verified
Files already downloaded and verified
Preparing model for quantization....
Calibrating...
Calibration complete....
Traceback (most recent call last):
  File "/esat/thalassa1/users/szhao/wrk/deep-learning/BNN/bayesian-torch/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py", line 618, in <module>
    main()
  File "/esat/thalassa1/users/szhao/wrk/deep-learning/BNN/bayesian-torch/bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py", line 346, in main
    traced_model = torch.jit.trace(model_int8, data)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.module(*inputs, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File ".conda-env/envs/bnn/lib/python3.10/site-packages/bayesian_torch/models/deterministic/resnet.py", line 122, in forward
    out = F.avg_pool2d(out, out.size()[3])
TypeError: avg_pool2d(): argument 'kernel_size' (position 2) must be tuple of ints, not Tensor

ranganathkrishnan commented 9 months ago

Hi @szhaoesat, the issue seems to come from tracing the model with torch.jit.trace. I have pushed a fix in commit f5c7126cb80ed7dc86b3c6dd55bc5c006d64e25a. Can you pull this fix and try again?

Thanks!

szhaoesat commented 9 months ago

Thanks for your help, now it works!