zkkli / I-ViT

[ICCV 2023] I-ViT: Integer-only Quantization for Efficient Vision Transformer Inference
Apache License 2.0

PyTorch inference and TVM inference give different results! #12

Open dedoogong opened 2 months ago

dedoogong commented 2 months ago

Hi, thanks for your great work!

I ran QAT for DeiT-Base and got 84 mAP after some epochs. Then I converted the model and ran evaluate_accuracy.py, but the result was totally different. Why? Can you give me a hint?

zkkli commented 2 months ago

Hi, I think it could be a problem with the version of TVM or Timm. For example, different timm versions may import modules with different names. Please try installing the version recommended in the README file.

dedoogong commented 2 months ago

Hi @zkkli! Thanks for your reply! I installed TVM 0.9.dev0 (python3 -c "import tvm; print(tvm.__version__)" prints 0.8.dev0). As you know, the git tag version and the actual package version differ (the package version is 0.1 lower than the tag).

Running evaluate_accuracy.py gives an error like this:

Traceback (most recent call last):
  File "evaluate_accuracy.py", line 103, in <module>
    main()
  File "evaluate_accuracy.py", line 77, in main
    func, params = build_model.get_workload(name=name,
  File "/home/user/workspace/I-ViT2/TVM_benchmark/models/build_model.py", line 70, in get_workload
    return create_workload(net, QuantizeInitializer())
  File "/home/user/workspace/I-ViT2/TVM_benchmark/models/utils.py", line 163, in create_workload
    mod = relay.transform.InferType()(mod)
  File "/home/user/workspace/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/user/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  9: TVMFuncCall
  8: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  0: tvm::relay::qnn::QuantizeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/user/workspace/tvm/src/relay/analysis/type_solver.cc", line 622
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [15:44:49] /home/user/workspace/tvm/src/relay/qnn/op/quantize.cc:78: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) is false: Output type should be one of [int8, unit8, int32] but was int16

I also cloned TVM again, checked out 0.10.dev0, and rebuilt it (python3 -c "import tvm; print(tvm.__version__)" prints 0.9.dev0). This time it gives an error like this:

Traceback (most recent call last):
  File "evaluate_accuracy.py", line 103, in <module>
    main()
  File "evaluate_accuracy.py", line 77, in main
    func, params = build_model.get_workload(name=name,
  File "/home/user/workspace/I-ViT2/TVM_benchmark/models/build_model.py", line 70, in get_workload
    return create_workload(net, QuantizeInitializer())
  File "/home/user/workspace/I-ViT2/TVM_benchmark/models/utils.py", line 163, in create_workload
    mod = relay.transform.InferType()(mod)
  File "/home/user/workspace/tvm_09dev0/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/user/workspace/tvm_09dev0/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  File "/home/user/workspace/tvm_09dev0/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: relay.concatenate requires all tensors have the same dtype

Can you suggest a solution? Please help! If you want, I can upload the weight files (pth.tar, npy). My CUDA version is 11.2 (I heard 11.2 is compatible with TVM 0.9.dev0). My timm version is 0.9.7, but I don't know whether timm is related to this issue, since the weights are overwritten by the QAT-trained weights.

Thank you very much!

dedoogong commented 2 months ago

Hi @zkkli, I tested this again with every TVM version from 0.10.dev0 to 0.17.dev0, cloning each one separately, checking it out, and building from source. All attempts failed with the same error:

  Check failed: (false) is false: relay.concatenate requires all tensors have the same dtype

zkkli commented 2 months ago

Hi, could you try aligning the timm version with the recommended one (0.4.12)?
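
To confirm which package versions the Python environment actually resolves, a minimal sketch like the following may help. The helper names (installed_version, version_report) are hypothetical, and the only pin taken from this thread is timm 0.4.12; the TVM version to expect should be read from the README.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def version_report(expected, lookup=installed_version):
    """Compare expected versions against what is installed.

    `expected` maps distribution names to the version string you want.
    Returns {name: (expected, found, matches)}.
    """
    report = {}
    for name, want in expected.items():
        found = lookup(name)
        report[name] = (want, found, found == want)
    return report


if __name__ == "__main__":
    # timm 0.4.12 is the version recommended in this thread.
    for name, (want, found, ok) in version_report({"timm": "0.4.12"}).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: expected {want}, found {found} -> {status}")
```

A TVM source build that is only added via PYTHONPATH may not be visible to importlib.metadata; in that case printing tvm.__version__ directly is the more reliable check.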

dedoogong commented 2 months ago

Hi @zkkli, after switching to timm 0.4.12 and making some other environment changes (I don't remember exactly which, as I tried a lot of things), I could run evaluate_accuracy.py again. But it still gives different results for a cat image.

PyTorch:
top5 labels: [285 282 356 281 278] Rank 1: 'Egyptian cat' Rank 2: 'tiger cat' Rank 3: 'weasel' Rank 4: 'tabby' Rank 5: 'kit fox'

TVM:
TVM top5 labels: [463 600 412 792 840] Rank 1: bucket Rank 2: hook Rank 3: ashcan Rank 4: shovel Rank 5: swab
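
To put a number on the disagreement, the two rankings can be compared with a small helper. This is only a sketch: the function names are hypothetical, and the label indices are the ones printed above. With raw logits from each backend, top_k could be applied first.

```python
def top_k(scores, k=5):
    """Return the indices of the k largest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def compare_top_k(labels_a, labels_b):
    """Summarize how two ranked label lists agree."""
    shared = set(labels_a) & set(labels_b)
    return {
        "top1_match": labels_a[0] == labels_b[0],
        "overlap": len(shared),
        "shared_labels": sorted(shared),
    }


# Top-5 label indices reported above.
pytorch_top5 = [285, 282, 356, 281, 278]
tvm_top5 = [463, 600, 412, 792, 840]

summary = compare_top_k(pytorch_top5, tvm_top5)
print(summary)
# {'top1_match': False, 'overlap': 0, 'shared_labels': []}
```

The two lists share no labels at all, so this looks like more than small numerical drift between the backends.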

By the way, the original I-ViT PyTorch code gives an error like this:

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
        size mismatch for blocks.0.norm1.norm_scaling_factor: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for blocks.0.norm2.norm_scaling_factor: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for blocks.1.norm1.norm_scaling_factor: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for blocks.1.norm2.norm_scaling_factor: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1]).
....

So I changed the size of IntLayerNorm's norm_scaling_factor buffer from 1 to 768.

class IntLayerNorm(nn.LayerNorm): 
...
        self.register_buffer('norm_scaling_factor', torch.zeros(768))

After that, convert_model.py ran without errors.
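
For anyone hitting the same size mismatch, a sketch like the one below lists every entry whose shape differs between checkpoint and model before loading. The function name shape_mismatches is hypothetical, and the shapes are copied from the error message above rather than read from a real checkpoint.

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """List (name, checkpoint_shape, model_shape) for entries whose shapes differ.

    Both arguments map parameter/buffer names to shape tuples.
    Names present on only one side are ignored here.
    """
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


# Shapes taken from the error message above.
checkpoint_shapes = {
    "blocks.0.norm1.norm_scaling_factor": (768,),
    "blocks.0.norm2.norm_scaling_factor": (768,),
}
model_shapes = {
    "blocks.0.norm1.norm_scaling_factor": (1,),
    "blocks.0.norm2.norm_scaling_factor": (1,),
}

for name, ckpt, model in shape_mismatches(checkpoint_shapes, model_shapes):
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

With PyTorch, each dictionary could be built as {k: tuple(v.shape) for k, v in state_dict.items()}, once from the checkpoint's state dict and once from model.state_dict().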

dedoogong commented 2 months ago

Hi @zkkli, it would be great if you could share your checkpoint.pth.tar and params.npy so I can debug this. Thank you!

dedoogong commented 1 month ago

Hi @zkkli, is there any update? The same problem is reported in https://github.com/zkkli/I-ViT/issues/6. Thank you!