Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Problems with training and accessing the weights #117

Closed tangyuelm closed 4 years ago

tangyuelm commented 4 years ago

Hello, thanks so much for the tool. I am trying to use this tool for a quantized VGG. However, the results show no difference when I change the bit width. I am sure I used QuantType.INT in the common.py file. I tried to access the int_weight property as mentioned in some other issues, but I failed to access anything. I added a function like this in the quant_conv.py file:

```python
def save_weight():
    print(QuantConv2d.int_weight)
```

I changed the common.py file as follows:

```python
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp
from brevitas.nn.quant_conv import save_weight

QUANT_TYPE = QuantType.INT
SCALING_MIN_VAL = 2e-16

ACT_SCALING_IMPL_TYPE = ScalingImplType.PARAMETER
ACT_SCALING_PER_CHANNEL = False
ACT_SCALING_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
ACT_MAX_VAL = 6.0
ACT_RETURN_QUANT_TENSOR = False
ACT_PER_CHANNEL_BROADCASTABLE_SHAPE = None
HARD_TANH_THRESHOLD = 10.0

WEIGHT_SCALING_IMPL_TYPE = ScalingImplType.STATS
WEIGHT_SCALING_PER_OUTPUT_CHANNEL = True
WEIGHT_SCALING_STATS_OP = StatsOp.MAX
WEIGHT_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
WEIGHT_NARROW_RANGE = True

ENABLE_BIAS_QUANT = False

HADAMARD_FIXED_SCALE = False


def make_quant_conv2d(in_channels, out_channels, kernel_size, stride, padding, groups, bias, bit_width,
                      enable_bias_quant=ENABLE_BIAS_QUANT,
                      weight_quant_type=QUANT_TYPE,
                      weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE,
                      weight_scaling_stats_op=WEIGHT_SCALING_STATS_OP,
                      weight_scaling_per_output_channel=WEIGHT_SCALING_PER_OUTPUT_CHANNEL,
                      weight_restrict_scaling_type=WEIGHT_RESTRICT_SCALING_TYPE,
                      weight_narrow_range=WEIGHT_NARROW_RANGE,
                      weight_scaling_min_val=SCALING_MIN_VAL):
    bias_quant_type = QUANT_TYPE if enable_bias_quant else QuantType.FP
    x = qnn.QuantConv2d(in_channels,
                        out_channels,
                        groups=groups,
                        kernel_size=kernel_size,
                        padding=padding,
                        stride=stride,
                        bias=bias,
                        bias_quant_type=bias_quant_type,
                        compute_output_bit_width=bias and enable_bias_quant,
                        compute_output_scale=bias and enable_bias_quant,
                        weight_bit_width=bit_width,
                        weight_quant_type=weight_quant_type,
                        weight_scaling_impl_type=weight_scaling_impl_type,
                        weight_scaling_stats_op=weight_scaling_stats_op,
                        weight_scaling_per_output_channel=weight_scaling_per_output_channel,
                        weight_restrict_scaling_type=weight_restrict_scaling_type,
                        weight_narrow_range=weight_narrow_range,
                        weight_scaling_min_val=weight_scaling_min_val)
    save_weight
    return x
```

But nothing is printed. So could you please show me a simple example of extracting the weights? And do you know what might cause the training accuracy to stay the same across bit widths?

tangyuelm commented 4 years ago

Besides, I see the training example of the BNN has a self.model.clip_weights(-1, 1) call in the trainer. If I want to train a quantized model, do I also need to add this call during training? If so, how do I choose the values for clamp(min_val, max_val)?

volcacius commented 4 years ago

Hello,

```python
def save_weight():
    print(QuantConv2d.int_weight)
```

This is not correct: if you are defining it inside QuantConv2d, you should call print(self.int_weight).
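
For reference, here is a minimal sketch of pulling the integer weights out of a whole model, assuming the Brevitas version used in this thread where QuantConv2d exposes an int_weight property (the helper name dump_int_weights is just illustrative):

```python
import brevitas.nn as qnn

def dump_int_weights(model):
    # Walk the model and print the integer weight tensor of every quantized conv layer.
    for name, module in model.named_modules():
        if isinstance(module, qnn.QuantConv2d):
            print(name, module.int_weight)
```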

Regarding self.model.clip_weights(-1, 1), this is typical when you are training with CONST scaling on the weights, as is the case for the BNN example. The values to pick come from the fact that when quantizing with weight_scaling_const=1, you are always quantizing to the interval [-1, 1]. Clipping the floating-point copy of the weights helps with convergence.
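
A minimal sketch of what such a clipping step could look like (this illustrative helper clamps all conv and linear weights; the actual BNN example trainer may differ):

```python
import torch
import torch.nn as nn

def clip_weights(model, min_val=-1.0, max_val=1.0):
    # Clamp the floating-point weight copies in place, typically after each optimizer step.
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                module.weight.clamp_(min_val, max_val)
```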

Regarding accuracy, it's hard to say what you are observing without more information, but in general if quant_type is set to QuantType.INT, you can be sure your weights are being quantized.

tangyuelm commented 4 years ago

quantize.zip Thanks for your quick reply. In fact, I had defined it outside the class before; now I have moved it inside and still nothing is printed. I have attached my code; Vggq8.py is now my main code. I am wondering whether I am misusing the Brevitas tools.

volcacius commented 4 years ago

Hello,

Line 66 of common.py should be x.save_weight() to print the weights, since it's a method, not x.save_weight. Also note that this way you are printing the weights only when you are initializing the layer. If you want to see the final values, you should also call the method at the end of training.
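
As a hypothetical usage sketch (the argument values below are placeholders), with save_weight defined as a method on QuantConv2d:

```python
conv = make_quant_conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1,
                         padding=1, groups=1, bias=False, bit_width=8)
# ... training loop updating conv ...
conv.save_weight()  # note the parentheses: a method call, not a bare attribute reference
```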

tangyuelm commented 4 years ago

Thank you very much. I can see the weights now. I will try changing the code to see the final values and what influences my accuracy.

tangyuelm commented 4 years ago

I tested more bit widths with VGG16 on CIFAR-10. The best accuracy in 400 epochs is 91.78 for int8, 91.49 for int4, and 91.06 for int3, while int2 does not converge. Are these results normal? I thought even the accuracy variation between different epochs (after convergence) would be bigger than that.

volcacius commented 4 years ago

I would first train in floating-point to get a reference. Obviously the choice of hyperparameters can make a huge difference.

tangyuelm commented 4 years ago

I tested floating point before and it is also above 91. In fact, my question is: is it normal that I have quantized it down to 3 bits and it still retains similar accuracy? My original understanding was that quantization decreases accuracy, and while int8 might keep similar accuracy, lower bit widths would show some difference.

volcacius commented 4 years ago

VGG16 is a highly overparametrized model, CIFAR-10 is a very small dataset, and Brevitas implements very good quantization techniques, so I would expect the difference in accuracy between different precisions to be small. With a better initialization on the activations (max_val=1.0) I expect 2 bits to converge decently too.
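
For example, here is a hedged sketch of what lowering the activation initialization could look like in common.py; the QuantReLU keyword names are assumptions based on the same-era API used for the conv layers above:

```python
ACT_MAX_VAL = 1.0  # was 6.0; a tighter initial range helps low-bit activations converge

def make_quant_relu(bit_width,
                    quant_type=QUANT_TYPE,
                    scaling_impl_type=ACT_SCALING_IMPL_TYPE,
                    max_val=ACT_MAX_VAL):
    # Illustrative helper; keyword names are assumed for this Brevitas version.
    return qnn.QuantReLU(bit_width=bit_width,
                         quant_type=quant_type,
                         scaling_impl_type=scaling_impl_type,
                         max_val=max_val)
```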

tangyuelm commented 4 years ago

Oh. I see. Thank you very much.