ben-hawks / pytorch-jet-classify

PyTorch implementation of the HLS4ML 3 Layer Jet Tagging model, including a standard PyTorch (float) and a quantized (via Xilinx's Brevitas library) implementation. Designed for use with the HLS4ML LHC Jet dataset (100 particles): https://zenodo.org/record/3602254

Understand brevitas.qnn.QuantReLU's `max_val` required parameter #4

Open ben-hawks opened 4 years ago

ben-hawks commented 4 years ago

Unlike stock PyTorch, Brevitas' qnn.QuantReLU has a required parameter specifying its maximum value. In the documentation and examples it is set to 6 (making it effectively equivalent to a ReLU6 activation function), and at first glance it doesn't seem to affect performance much, but this should be better understood and, if possible and sensible, worked around to mimic the unquantized implementation.
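For intuition, here is a pure-Python sketch of the arithmetic behind a ReLU with a fixed `max_val` (this is an illustration, not Brevitas' actual implementation): with an unsigned b-bit output, `max_val` fixes the scale factor, and everything above it saturates, which is why `max_val=6.0` behaves like ReLU6.

```python
def quant_relu(x, max_val=6.0, bits=4):
    """Quantize a ReLU output to `bits` unsigned levels over [0, max_val]."""
    n_levels = 2 ** bits - 1                 # e.g. 15 codes for 4 bits
    scale = max_val / n_levels               # real value of one quantization step
    x = max(x, 0.0)                          # the ReLU part
    code = min(round(x / scale), n_levels)   # quantize and saturate at max_val
    return code * scale                      # dequantized value

print(quant_relu(2.0))   # representable exactly with scale 0.4
print(quant_relu(9.0))   # saturates at max_val = 6.0
```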

volcacius commented 4 years ago

Hello,

max_val has different implications depending on how scaling is implemented (which is driven by scaling_impl_type). By default QuantReLU uses ScalingImplType.PARAMETER, meaning that the underlying scale factor (== the training-time floating point value that maps to 1) is a PyTorch nn.Parameter, learned with backprop. In that case, max_val is used (together with the number of bits) to compute the initialization for the nn.Parameter.
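In other words, under PARAMETER scaling `max_val` only supplies the initial value of the learned scale. A minimal sketch of that initialization (hypothetical helper, not Brevitas internals):

```python
def init_scale_from_max_val(max_val, bits, signed=False):
    """Initial scale: the real step size that puts max_val at the top code."""
    top_code = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
    return max_val / top_code

# After this, backprop moves the scale freely; max_val plays no further role.
print(init_scale_from_max_val(6.0, 4))  # 6.0 / 15 = 0.4
```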

In your implementation, you are declaring a single "act", and then calling it multiple times in your forward pass. That means that the underlying nn.Parameter is learning to adapt to the numerical range required by all the different places you are calling it from. Given that there is no normalization in the topology, those numerical ranges might be very different, and the accuracy could suffer. Unless there is a specific reason hardware-wise to share that layer, I would try to declare separate activation layers.

If instead you used ScalingImplType.CONST, the scale factor would be a constant (again computed based on max_val and the number of bits), and then sharing the activation layer or having multiples wouldn't make any difference.

Finally, there is also an option to not pick a max_val manually at all. Let's say, for example, you are retraining from a pretrained floating-point network with no normalization, where the numerical ranges of the activations are all over the place. A single hand-picked max_val=6.0 is not going to be a good initialization for all your nn.Parameters, and your network might not converge at all (especially at low precision on a complex topology+dataset). You can set scaling_impl_type = ScalingImplType.STATS, which basically learns the scale factor as some statistical function of your activation, where the function is specified by scaling_stats_op. This is done in the style of batch norm, meaning that at training time you compute the function on the current batch of data and backprop through it, but you also update an exponential moving average with the result, which is then used at inference time. The default function is StatsOp.MEAN_LEARN_SIGMA_STD, which assumes that the incoming data is Gaussian distributed and computes mean + sigma * standard_deviation of it, with sigma learned and initialized based on scaling_stats_sigma (default = 2.0).
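The batch-norm-style mechanics described above can be sketched in pure Python (Brevitas' actual StatsOp implementations differ in detail; `sigma` here plays the role of scaling_stats_sigma):

```python
import statistics

class StatsScale:
    """Track mean + sigma * std per batch, plus an EMA for inference time."""

    def __init__(self, sigma=2.0, momentum=0.1):
        self.sigma = sigma
        self.momentum = momentum
        self.ema = None  # running estimate used at inference time

    def batch_stat(self, batch):
        # mean + sigma * std, assuming roughly Gaussian activations
        return statistics.mean(batch) + self.sigma * statistics.pstdev(batch)

    def update(self, batch):
        stat = self.batch_stat(batch)
        self.ema = stat if self.ema is None else (
            (1 - self.momentum) * self.ema + self.momentum * stat)
        return stat  # training uses the batch value and backprops through it

tracker = StatsScale()
tracker.update([0.0, 1.0, 2.0, 3.0])  # 1.5 + 2 * 1.118... ≈ 3.736
print(tracker.ema)
```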
However, in my experience, accuracy-wise the best results are obtained with ScalingImplType.PARAMETER. The idea then, when you don't know how to init your scale factor, is to first do some iterations/epochs with ScalingImplType.STATS, save the checkpoint, and then restart training with ScalingImplType.PARAMETER. In this case, you still have to specify a max_val (because of how it's currently implemented), but when you load the checkpoint Brevitas takes care of using the exponential moving average accumulated with STATS to initialize your PARAMETER, so max_val is basically ignored.
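A sketch of that two-stage recipe (hypothetical names and checkpoint layout, not Brevitas' API): the PARAMETER-phase scale is initialized from the EMA accumulated during the STATS phase when one is available, so the mandatory `max_val` only matters on a cold start.

```python
def init_parameter_scale(checkpoint, max_val, bits):
    """Prefer the EMA from a STATS-phase checkpoint; fall back to max_val."""
    if checkpoint and "scale_ema" in checkpoint:
        return checkpoint["scale_ema"]    # max_val is effectively ignored
    return max_val / (2 ** bits - 1)      # cold-start initialization

# Stage 1 ended with an accumulated running statistic:
ckpt = {"scale_ema": 0.37}
print(init_parameter_scale(ckpt, max_val=6.0, bits=4))  # 0.37, not 0.4
print(init_parameter_scale(None, max_val=6.0, bits=4))  # 0.4
```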

Hope it's clear, otherwise let me know.

Alessandro