Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Distribution bug with autograd_ste_ops cpp extension #237

Closed: vfdev-5 closed this issue 3 years ago

vfdev-5 commented 3 years ago

Hi,

I'm installing brevitas from master as suggested in the README:

pip install git+https://github.com/Xilinx/brevitas.git

and the installed package does not contain the C++ files, so loading the cpp extension won't work:

ls /opt/conda/lib/python3.8/site-packages/brevitas/
__init__.py  config.py  core  export  function  graph  inject  jit.py  loss  nn  onnx  proxy  quant_tensor.py  utils

and if I slightly modify __init__.py, where the loading happens:

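import glob
import os

# Directory holding the extension's C++ sources, next to this __init__.py
extensions_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'csrc')
sources = glob.glob(os.path.join(extensions_dir, '*.cpp'))
# glob already returns absolute paths, so this join leaves them unchanged
sources = [os.path.join(extensions_dir, s) for s in sources]

# Debug lines: check that the C++ sources were actually installed
print("sources:", sources)
assert len(sources) > 0 and all([os.path.exists(p) for p in sources])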

it fails because sources is empty:

BREVITAS_VERBOSE=1 python my_test_brevitas_script.py
sources: []
  File "/opt/conda/lib/python3.8/site-packages/brevitas/__init__.py", line 22, in <module>
    assert len(sources) > 0 and all([os.path.exists(p) for p in sources])
AssertionError
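A minimal sketch of the same check as a reusable helper, which can be handy for confirming whether an installed package actually ships its C++ sources. The name find_cpp_sources and the csrc default are illustrative assumptions, not part of the Brevitas API.

```python
import glob
import os
from typing import List


def find_cpp_sources(package_dir: str, subdir: str = "csrc") -> List[str]:
    """Return the .cpp files under package_dir/subdir, or raise if there are none.

    Hypothetical helper: it mirrors the check above but fails with an
    explicit message instead of a bare AssertionError.
    """
    extensions_dir = os.path.join(os.path.abspath(package_dir), subdir)
    # glob already returns full paths, so no extra os.path.join is needed.
    sources = sorted(glob.glob(os.path.join(extensions_dir, "*.cpp")))
    if not sources:
        raise FileNotFoundError(
            f"no .cpp sources in {extensions_dir}; "
            "the package was probably installed without its C++ files"
        )
    return sources
```

For example, calling find_cpp_sources(os.path.dirname(brevitas.__file__)) on the install listed above would raise, since that directory has no csrc folder.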

volcacius commented 3 years ago

Hello,

Yeah, I messed up the packaging, sorry about that. It should be fixed in the dev branch; can you give it a try? I'll propagate the fix to master sooner rather than later to avoid further issues. For what it's worth, the cpp backend is only useful to get end-to-end compilation with the just-in-time compiler backend enabled (env BREVITAS_JIT=1, which is disabled by default) during distributed data parallel training (or with similar distributed backends).
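Since BREVITAS_JIT is an opt-in environment variable, it has to be set when the process starts, e.g. BREVITAS_JIT=1 python train.py. The sketch below shows how such a flag is commonly parsed; env_flag and the accepted spellings are hypothetical, not Brevitas's actual implementation.

```python
import os
from typing import Mapping, Optional

_TRUE_VALUES = {"1", "true", "yes", "on"}
_FALSE_VALUES = {"0", "false", "no", "off", ""}


def env_flag(name: str, default: bool = False,
             environ: Optional[Mapping[str, str]] = None) -> bool:
    """Parse a boolean feature flag such as BREVITAS_JIT from the environment.

    Hypothetical helper: the accepted spellings are an assumption.
    """
    env = os.environ if environ is None else environ
    raw = env.get(name)
    if raw is None:
        return default  # flag not set: keep the default (JIT off)
    value = raw.strip().lower()
    if value in _TRUE_VALUES:
        return True
    if value in _FALSE_VALUES:
        return False
    raise ValueError(f"{name}={raw!r} is not a recognized boolean value")


# Read once at import time, the way a library-level config flag usually is.
JIT_ENABLED = env_flag("BREVITAS_JIT")
```

Because a flag like this is typically read once when the library is imported, changing os.environ from Python after the import would likely have no effect.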

Alessandro

vfdev-5 commented 3 years ago

Thanks for the reply. Yes, I can try the dev branch: https://github.com/Xilinx/brevitas/tree/dev

I was wondering why this cpp extension is compiled just in time rather than built during installation. With DDP, I currently see about a 2x slowdown when training ResNet18 with 8-bit weights and activations compared to plain ResNet18, on 2 GPUs with NCCL. Do you think that enabling BREVITAS_JIT=1 and properly compiling the C++ extension could reduce the gap?

volcacius commented 3 years ago

It's unfortunately complicated. There are a few things to know about this.

Until recently, BREVITAS_JIT was enabled by default and the extension was built during installation, rather than being compiled and loaded at runtime. What BREVITAS_JIT does is enable compilation of all quantizers with PyTorch's just-in-time compiler. Depending on the model, it can provide large improvements in both runtime and memory consumption. The cpp extension is there because all quantizers depend on a few torch.autograd.Function subclasses, and the compiler still doesn't support those when they are written in Python (only in C++), so I need an alternative C++ implementation when I want the quantizers to be fully compiled. Without it (i.e. with only some parts of the quantizer compiled), DistributedDataParallel would raise an error: DDP requires the model to be cloneable, and cloning things that switch back and forth between Python and native code is not supported in PyTorch.
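The all-or-nothing switch described above is often implemented by choosing a decorator once, based on the flag, and applying it to every quantizer. A minimal sketch of that pattern, assuming a hypothetical make_script_decorator helper (not necessarily how Brevitas's own jit.py is organized):

```python
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)


def make_script_decorator(jit_enabled: bool,
                          compiler: Callable[[F], F]) -> Callable[[F], F]:
    """Return `compiler` when JIT is enabled, otherwise an identity decorator.

    In a PyTorch code base `compiler` would typically be torch.jit.script;
    it is a parameter here so the pattern can be shown without torch.
    """
    if jit_enabled:
        return compiler

    def identity(fn: F) -> F:
        return fn  # interpreted path: leave the function untouched

    return identity
```

The design keeps a single source for each quantizer: the same Python code either runs in the interpreter or is handed to the compiler, depending on one flag.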

PyTorch's JIT compiler is still somewhat rudimentary, so I spent a lot (really a lot) of time writing code in a way that adheres to all the restrictions imposed by TorchScript (the subset of Python supported by the compiler) while still making sure all the features I was interested in were supported. Really, the whole library was designed around that. However, a few weeks ago I accidentally realized that in some circumstances (explained below) having BREVITAS_JIT enabled was making training diverge or converge to much lower accuracy, meaning that the interpreted and compiled versions of the code give very different results. This discovery made me question around 9 months of training experiments, during which I was sure I was working on good ideas and yet training would never converge decently no matter what, which honestly was very disappointing. So I took the (painful) decision of disabling BREVITAS_JIT by default and no longer recommending it.

The numerical problems with the JIT compiler arise when you use ParameterFromRuntimeStatsScaling for your activation scaling (i.e. ScalingImplType.PARAMETER_FROM_STATS, the current default) rather than ParameterScaling (i.e. ScalingImplType.PARAMETER). The advantage of ParameterFromRuntimeStatsScaling is that it doesn't require any input from the user, while ParameterScaling requires an initialization value. This makes quantization of pretrained floating-point models much easier, since the user doesn't have to guess appropriate activation ranges, and it also means that in general the minimum amount of information a user has to specify when doing QAT is just the number of bits. In turn, that means I can build higher-level APIs (still WIP) where quantization is performed at the graph level on the whole model in one line of code, just by passing the number of bits for weights and activations. So ParameterFromRuntimeStatsScaling will stay as the default implementation of activation scaling, but it also means that BREVITAS_JIT can't be enabled by default.
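To make the difference concrete: a stats-based scale needs no user input because it is derived from the activations observed during the first training steps. The sketch below is a simplified, pure-Python illustration of that idea; StatsInitScale, the nearest-rank percentile, the number of collection steps and the averaging are all assumptions, not Brevitas's ParameterFromRuntimeStatsScaling implementation.

```python
import math
from typing import List, Optional, Sequence


def percentile(values: Sequence[float], q: float) -> float:
    """Nearest-rank percentile, with q in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100.0 * len(ordered)))
    return ordered[rank - 1]


class StatsInitScale:
    """Hypothetical unsigned activation scale initialized from runtime statistics.

    For the first `collect_steps` batches the scale follows the running mean
    of a per-batch percentile; afterwards it is frozen (in real QAT it would
    become a learnable parameter updated by backprop).
    """

    def __init__(self, bit_width: int = 8, collect_steps: int = 3,
                 q: float = 99.9) -> None:
        self.int_max = 2 ** bit_width - 1  # unsigned range [0, 2^b - 1]
        self.collect_steps = collect_steps
        self.q = q
        self.stats: List[float] = []
        self.scale: Optional[float] = None  # set once collection ends

    def __call__(self, batch: Sequence[float]) -> float:
        if self.scale is not None:
            return self.scale
        self.stats.append(percentile(batch, self.q))
        scale = (sum(self.stats) / len(self.stats)) / self.int_max
        if len(self.stats) == self.collect_steps:
            self.scale = scale  # freeze: switch from statistics to parameter
        return scale
```

The point of the sketch is that the user only chooses the bit width; the range is inferred from data, whereas a parameter-based scale would need an explicit initial value.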

With BREVITAS_JIT disabled by default, most users won't need the cpp backend. Having it built at installation time means a dependency on PyTorch in my setup.py. Because PyTorch is not fully distributed on PyPI, I can't easily adhere to PEP 517/PEP 518, which in turn creates restrictions on the way I can do packaging, CI/CD and so on. Having the extension built at runtime when needed removes this problem, meaning that in the long term I could switch to a pyproject.toml. It doesn't make any difference in terms of performance; it's just a matter of when your C++ compiler is invoked, and for most users it won't make any difference anyway since they won't use it.
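Building "at runtime when needed" boils down to deferring the compile step until the compiled path is first requested, then caching the result. A hedged sketch with a hypothetical make_lazy_loader helper; in PyTorch the build callable would typically wrap torch.utils.cpp_extension.load(name=..., sources=...), which is left out so the example runs without a C++ toolchain.

```python
import functools
from typing import Any, Callable, Optional


def make_lazy_loader(build: Callable[[], Any]) -> Callable[[bool], Optional[Any]]:
    """Return a getter that compiles the extension on first use, only if JIT is on.

    `build` is a hypothetical zero-argument callable that compiles and returns
    the extension module; after a successful build it is not called again.
    """

    @functools.lru_cache(maxsize=None)
    def _build_once() -> Any:
        return build()

    def get(jit_enabled: bool) -> Optional[Any]:
        if not jit_enabled:
            return None  # interpreted path: no C++ compiler is ever invoked
        return _build_once()

    return get
```

With this arrangement installation never touches torch or a compiler; the cost moves to the first JIT-enabled run, which matches the point that only the timing of the compile changes.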

Hope this explains it. Regarding enabling BREVITAS_JIT: again, you should do it only after first switching to scaling_impl_type = ScalingImplType.PARAMETER in all your quantized activation layers. For a pretrained CNN with residual connections and batch norm like ResNet18 at 8 bits, max_val = 20 is typically a good choice of init for unsigned activations.
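As a rough worked example of what max_val = 20 means at 8 bits, assuming the unsigned integer range [0, 255] and the common convention scale = max_val / (2^8 - 1): the scale is 20 / 255 ≈ 0.0784, so activations are represented in steps of about 0.078 and anything above 20 is clipped. fake_quant_unsigned is a hypothetical helper; in Brevitas max_val initializes a learnable scale, whereas here it is held fixed.

```python
def fake_quant_unsigned(x: float, max_val: float = 20.0, bit_width: int = 8) -> float:
    """Quantize-dequantize one value on an unsigned grid with a fixed scale.

    Assumes the integer range [0, 2^bit_width - 1]; with max_val = 20 and
    8 bits the step is 20 / 255, roughly 0.0784.
    """
    int_max = 2 ** bit_width - 1
    scale = max_val / int_max
    q = round(x / scale)          # nearest integer level (Python rounds ties to even)
    q = min(max(q, 0), int_max)   # clip to [0, int_max]
    return q * scale
```

For instance, an input of 1.0 lands on level 13 (about 1.02), and any input of 20 or more saturates at level 255.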

Alessandro

vfdev-5 commented 3 years ago

Thanks, Alessandro, for a very detailed response!

vfdev-5 commented 3 years ago

I tried the dev branch and it compiled the cpp extension as expected, so I could launch my DDP training with BREVITAS_JIT=1 (without changing scaling_impl_type). The result: test accuracy and training time are roughly the same as without JIT.

So I was just wondering: is it expected that QAT training can be ~2x slower than fp32?

volcacius commented 3 years ago

I'm glad to hear about your good results. Regarding the slowdown, yes, in general it's expected. QAT adds a lot of element-wise operations (with very low arithmetic intensity), so it's easy to become bandwidth bound. 2x is actually not that bad; I've trained quantized LSTMs that were 10 to 100x slower. If you have a GPU with HBM you might see better results, but there will always be a slowdown. The issue is not so much weight quantization as activation quantization. Something you can do to speed up your overall time to convergence is to first train with only quantized weights, and then retrain with quantized weights and activations. At higher precision it doesn't make a huge difference, but at low precision on difficult topologies (say you are doing a 3-bit EfficientNet) it can help you get there faster. You could do better with optimized CUDA kernels, but in general the idea of Brevitas is to prioritize flexibility over speed. We train mainly for custom processors/accelerators, so you need the flexibility to model the hardware datapath as closely as possible. The benefit you should see with BREVITAS_JIT=1 is mainly reduced memory consumption, which means you can use a larger batch size. Make sure to use an up-to-date version of PyTorch; memory consumption improved a lot between 1.1.0 and more recent versions.
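The two-stage recipe can be written down as a simple schedule. A minimal sketch with hypothetical names (QuantPhase, phase_for_epoch); how the flags are applied, e.g. by rebuilding the model with quantized activation layers and loading the weight-only checkpoint before stage 2, depends on your own training code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QuantPhase:
    quantize_weights: bool
    quantize_activations: bool


def phase_for_epoch(epoch: int, weight_only_epochs: int) -> QuantPhase:
    """Stage 1: quantized weights, float activations (cheaper per step, since
    activation quantization is the main source of slowdown).
    Stage 2: quantized weights and activations, starting from the stage-1 result.
    """
    if epoch < weight_only_epochs:
        return QuantPhase(quantize_weights=True, quantize_activations=False)
    return QuantPhase(quantize_weights=True, quantize_activations=True)
```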

Alessandro

vfdev-5 commented 3 years ago

Thanks a lot for the explanation!