YCG09 / chinese_ocr

CTPN + DenseNet + CTC based end-to-end Chinese OCR implemented using tensorflow and keras
Apache License 2.0
2.73k stars 1.08k forks

Model quantization problem #374

Open yoummiegao opened 1 year ago

yoummiegao commented 1 year ago

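I want to deploy the model on an edge device without a GPU, so I need to convert it to an int8 TFLite model. I am using Keras's default QAT (quantization-aware training) implementation.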
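For readers less familiar with the term: QAT inserts "fake quantization" steps into the forward pass during training, so the weights can adapt to the rounding that an int8 model will apply later. The sketch below illustrates that quantize-then-dequantize round trip in plain Python. It assumes a simplified per-tensor affine (asymmetric) scheme over the signed int8 range; the names `QuantParams`, `choose_params`, `quantize`, `dequantize`, and `fake_quantize` are hypothetical, and this is not the code that tfmot actually runs.

```python
from dataclasses import dataclass

QMIN, QMAX = -128, 127  # signed int8 range


@dataclass(frozen=True)
class QuantParams:
    """Per-tensor affine int8 parameters (hypothetical helper type)."""
    scale: float
    zero_point: int


def choose_params(min_val: float, max_val: float) -> QuantParams:
    """Pick scale and zero point so that [min_val, max_val] maps onto int8."""
    # Widen the range to include 0.0 so that zero is represented exactly.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    if max_val == min_val:
        return QuantParams(scale=1.0, zero_point=0)
    scale = (max_val - min_val) / (QMAX - QMIN)
    zero_point = round(QMIN - min_val / scale)
    zero_point = max(QMIN, min(QMAX, zero_point))
    return QuantParams(scale, zero_point)


def quantize(x: float, p: QuantParams) -> int:
    """Map a real value to an int8 code, clamping out-of-range values."""
    q = round(x / p.scale) + p.zero_point
    return max(QMIN, min(QMAX, q))


def dequantize(q: int, p: QuantParams) -> float:
    """Map an int8 code back to the real value it represents."""
    return (q - p.zero_point) * p.scale


def fake_quantize(x: float, p: QuantParams) -> float:
    """Round trip used to simulate int8 error while staying in float."""
    return dequantize(quantize(x, p), p)


if __name__ == "__main__":
    params = choose_params(-1.0, 3.0)
    for value in (-1.0, 0.0, 0.5, 3.0, 10.0):
        print(value, "->", quantize(value, params), "->", fake_quantize(value, params))
```

Values outside the calibrated range are clamped, and in-range values pick up a rounding error of at most half a quantization step. The training script that triggers the problem follows.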

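    # Imports this excerpt appears to rely on (not shown in the original post; assumed):
    # import tensorflow as tf
    # import tensorflow_model_optimization as tfmot
    # from tensorflow.keras import layers as KL
    # from tensorflow.keras.models import Model
    input = KL.Input(shape=(args.height, None, 3), name='input')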
    labels = KL.Input(name='labels', shape=[None], dtype='int64')
    # labels = KL.Input(name='labels', shape=[None], dtype='float32')
    # labels = KL.Input(name='labels', shape=[1,1], dtype='float32')
    input_length = KL.Input(name='input_length', shape=[1], dtype='int64')
    label_length = KL.Input(name='label_length', shape=[1], dtype='int64')
    y_predict = densenet(input=input, num_classes=nclass)
    basemodel = Model(inputs=input, outputs=y_predict)
    basemodel.summary()
    loss_out = KL.Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_predict, labels, input_length, label_length])
    model = Model(inputs=[input, labels, input_length, label_length], outputs=loss_out)
    model.summary()
    # model.compile(optimizer='adam', loss='categorical_crossentropy') #need to further check
    # this is automatically done by inner built-in calculation
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])
    # model.compile(loss={'ctc':loss_out}, optimizer='adam', metrics=['accuracy'])
    # customcallback = CustomCallback()
    init_epoch = 0
    if args.restore is not None:
        restore_path=''
        if args.restore_path is not None:
            restore_path = args.restore_path
        else:
            restore_path = chkp_dir+'/chkp-{:04}.chkp'.format(args.restore)
        print('Restore from epoch: ', args.restore, restore_path)
        if restore_path.endswith('h5'):
            model.load_weights(restore_path, by_name=True, skip_mismatch=True)
        else:
            model.load_weights(restore_path)
        init_epoch = int(args.restore)
    quantize_model = tfmot.quantization.keras.quantize_model
    q_aware_model = quantize_model(model)
    # q_aware_model = quantize_model(basemodel)
    q_aware_model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])
    if args.restore_qat is not None:
        if args.restore_qat_path is not None:
            restore_qat_path = args.restore_qat_path
        else:
            restore_qat_path = chkp_qat_dir+'/chkp-{:04}.chkp'.format(args.restore_qat)
        print('Restore from epoch: ', restore_qat_path)
        q_aware_model.load_weights(restore_qat_path)
        # model.load_model(restore_path)
        # model = tf.keras.models.load_model(save_path)
        init_epoch_qat = int(args.restore_qat)
        q_aware_model.summary()
        print('-----------Start qat training-----------')
        q_aware_model.fit(
            train_dataset,
            epochs=args.epoch_qat,
            validation_data=val_dataset,
            initial_epoch=init_epoch,
            callbacks=[
                tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1, update_freq=500, write_images=True),
                tf.keras.callbacks.ModelCheckpoint(filepath=chkp_path, save_weights_only=True, verbose=1),
                tf.keras.callbacks.LearningRateScheduler(
                    tf.keras.optimizers.schedules.PiecewiseConstantDecay(lr_boundaries, lr_values),
                    verbose=1,
                ),
                # customcallback
                # tf.keras.callbacks.LearningRateScheduler(lambda epoch: float(learning_rate[epoch]))
            ],
        )

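However, QAT training fails with the error below. Has anyone run into a similar problem? Is there a way to fix it, given the error message? Many thanks!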

    File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/ops.py", line 1939, in _create_c_op
        raise ValueError(e.message)
    ValueError: Exception encountered when calling layer "batch_normalization" (type BatchNormalization).
    Shape must be rank 4 but is rank 7 for '{{node batch_normalization/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=1.1e-05, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization/ReadVariableOp, batch_normalization/ReadVariableOp_1, batch_normalization/FusedBatchNormV3/ReadVariableOp, batch_normalization/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,1,1,?,16,?,64], [64], [64], [64], [64].
    Call arguments received:
      • inputs=tf.Tensor(shape=(1, 1, 1, None, 16, None, 64), dtype=float32)
      • training=None
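One way to read the message: `FusedBatchNormV3` with `data_format="NHWC"` expects a rank-4 input, but the tensor it receives has rank 7. The small sketch below compares the reported shape with the expected rank. It assumes shapes are written as tuples with `None` standing for an unknown (`?`) dimension; the helper `extra_leading_dims` is hypothetical and exists only for illustration. It does not identify where the extra dimensions come from.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None stands for an unknown ("?") dimension


def extra_leading_dims(shape: Sequence[Optional[int]], expected_rank: int) -> Shape:
    """Return the leading dimensions beyond expected_rank (empty if none)."""
    surplus = len(shape) - expected_rank
    if surplus < 0:
        raise ValueError(f"rank {len(shape)} is below expected rank {expected_rank}")
    return tuple(shape[:surplus])


# Shape copied from the error message above.
reported: Shape = (1, 1, 1, None, 16, None, 64)

# Rank 4 is what the error message says the op requires.
leading = extra_leading_dims(reported, expected_rank=4)
trailing = reported[len(leading):]

print("rank:", len(reported))
print("extra leading dims:", leading)
print("trailing NHWC candidate:", trailing)
```

The trailing four dimensions look like an ordinary NHWC feature map (unknown batch, height 16, unknown width, 64 channels), while three size-1 dimensions have been prepended. If that reading is right, the input reaching `batch_normalization` is being wrapped in extra singleton dimensions somewhere after `quantize_model` rebuilds the graph, so comparing the input shapes of `model` and `q_aware_model` may help narrow it down.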