fastmachinelearning / hls4ml

Machine learning on FPGAs using HLS
https://fastmachinelearning.org/hls4ml
Apache License 2.0

Fix for 2D conv layers in the special case of io_parallel with full parallelization #760

Closed · drankincms closed this 1 year ago

drankincms commented 1 year ago

Description

This is a small fix for 2D conv layers when io_parallel is used and ParallelizationFactor is set to the output size. Currently this causes the pipeline pragma to be ignored, which results in no unrolling and a very large latency/II. This fix adds explicit unroll pragmas in that case to restore the expected behavior.
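
Roughly, the situation and the change look like this (a minimal sketch with illustrative names and bounds, not the actual generated code):

// Sketch only: illustrative loop structure, not the actual hls4ml template.
// With ParallelizationFactor == output size (here 98 = 7*14 output pixels of
// the example model below), the PIPELINE pragma that would normally be emitted
// on the partition loop is ignored, so nothing gets unrolled. The fix adds an
// explicit UNROLL on the inner multiplication loop:
void conv_mult_sketch(const float data[98][9], const float weights[9], float res[98]) {
PartitionLoop:
    for (int p = 0; p < 98; p++) {
    MultLoop:
        for (int i = 0; i < 9; i++) {
            #pragma HLS UNROLL
            res[p] += data[p][i] * weights[i];
        }
    }
}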

Type of change

- Bug fix (change that fixes an issue)

Tests

I have tested with the small model below:

# simplified CNN as example
import hls4ml
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Sequential
from qkeras import QActivation, QConv2D, QDense, quantized_bits, quantized_relu

nbits = 4
sym = 1

model = Sequential()
model.add(Input((9, 16, 1), name='input_student'))
model.add(QConv2D(1, (3, 3),
                  kernel_quantizer=quantized_bits(nbits, 0, alpha=sym),
                  bias_quantizer=quantized_bits(nbits, 0, alpha=1),
                  name='Student_Conv1a'))
model.add(QActivation('quantized_relu(' + str(nbits) + ')'))
model.add(Flatten())
model.add(QDense(10, name='fc1',
                 kernel_quantizer=quantized_bits(nbits, 0, alpha=1),
                 bias_quantizer=quantized_bits(nbits, 0, alpha=1)))
model.add(QActivation(activation=quantized_relu(nbits), name='relu1'))
model.add(QDense(10, name='fc2',
                 kernel_quantizer=quantized_bits(nbits, 0, alpha=1),
                 bias_quantizer=quantized_bits(nbits, 0, alpha=1)))
model.add(QActivation(activation=quantized_relu(nbits), name='relu2'))
model.add(Dense(2, name='output'))

model.summary()
model.compile(optimizer='adam', loss=['mse'], metrics=['mse'])

# HLS4ML: extraction of the model
config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_reuse_factor=1)
# full parallelization: PF equals the conv output size (7 * 14 = 98)
config['LayerName']['Student_Conv1a']['ParallelizationFactor'] = int(98 / 1)

print("-----------------------------------")
print(config)
print("-----------------------------------")

# parameters for the conversion configuration
cfg = hls4ml.converters.create_config(backend='Vivado')
cfg['IOType']     = 'io_parallel'  # or io_stream
cfg['Strategy']   = 'latency'
cfg['HLSConfig']  = config
cfg['KerasModel'] = model
cfg['OutputDir']  = 'hls4ml_test_prj/'
cfg['XilinxPart'] = 'xcvu13p-flga2577-1-e'
cfg['Part']       = 'xcvu13p-flga2577-1-e'

hls_model = hls4ml.converters.keras_to_hls(cfg)
hls_model.compile()

It's not clear to me if this can be easily incorporated into a test. If this is necessary, let me know and I will try a bit harder.


vloncar commented 1 year ago

The construct

if (condition) {
    #pragma HLS ...
}

doesn't work on Vitis, so this fix will be ignored on that backend. Perhaps we can always unroll, and ensure the factor is correct if that is needed?
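
Something along these lines (a sketch only; the loop name and the factor value are illustrative):

// Sketch of the unconditional form: the pragma is always emitted, and the
// factor is chosen to match what the PIPELINE pragma would have implied,
// so the other cases are unaffected.
void mult_sketch(const float buf[9], float acc[9]) {
ReuseLoop:
    for (int i = 0; i < 9; i++) {
        #pragma HLS UNROLL factor=9
        acc[i] += buf[i];
    }
}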

drankincms commented 1 year ago

Ah dang, ok. @vloncar, you're suggesting to just set the unroll factor to whatever it would have been under the pipeline pragma, so that the change is invisible in all cases except this one?

drankincms commented 1 year ago

Ok, I have updated the PR to remove the

if (condition) {
    #pragma HLS ...
}

constructs. I also noticed that the pipeline pragma should have a rewind option. If we would prefer this not be the case, let me know and I can get rid of that change.
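
For reference, a sketch of the pipeline pragma with rewind (illustrative names; rewind keeps the pipeline running continuously across successive invocations of the loop instead of draining in between):

// Sketch only: a pipelined loop with the rewind option.
void top_sketch(const float in[98], float out[98]) {
DataLoop:
    for (int p = 0; p < 98; p++) {
        #pragma HLS PIPELINE II=1 rewind
        out[p] = in[p] * 2.0f;
    }
}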

vloncar commented 1 year ago

I tested this and it doesn't break the Vitis backend. One question: why do you force the unroll of the loops that iterate over n_in but not the ones that iterate over n_out? The latter are automatically unrolled in your example (since they have only 1 iteration), but I'm not sure they would be in the general case.

vloncar commented 1 year ago

I checked this, and indeed we need to unroll the other loops as well. Also, it turns out we don't need the unroll factor: in the case of RF=1 and PF=1 the factor is ignored, since the bare pragma already means complete unrolling. In cases of RF>1 the enclosing loop will be pipelined, and that will control the unrolling, again ignoring the unroll factor.
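
So the loops end up roughly like this (a sketch with illustrative names and bounds; note the bare UNROLL with no factor on both loops):

// Sketch of the final form: bare UNROLL on both the n_out and n_in loops.
// With RF=1/PF=1 this means complete unrolling; with RF>1 the pipelined
// enclosing loop controls the unrolling, so no explicit factor is needed.
void mult_sketch(const float data[9], const float weights[9][2], float res[2]) {
OutLoop:
    for (int o = 0; o < 2; o++) {
        #pragma HLS UNROLL
        float acc = 0;
    InLoop:
        for (int i = 0; i < 9; i++) {
            #pragma HLS UNROLL
            acc += data[i] * weights[i][o];
        }
        res[o] = acc;
    }
}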

drankincms commented 1 year ago

@vloncar I was going to test your suggestion to confirm it is needed but it looks like you beat me to it :) Thanks a lot!

One question about removing the unroll factors: I was concerned (though I hadn't checked) that always fully unrolling might cause long synthesis times when n_partitions is large. Is this true? Or will the PIPELINE pragma cause them all to be unrolled anyway, so it doesn't matter?

vloncar commented 1 year ago

In the tests with your model I didn't see this. The logs always contain a warning that the factor directive is ignored and all loops will be fully unrolled. There's no change in the synthesis results; we just get rid of that warning.

drankincms commented 1 year ago

Ok, great, this was my hunch but I hadn't gotten a chance to check. 👍 Thanks!