Open yoummiegao opened 1 year ago
由于希望将模型部署到端侧的非 GPU 设备，需要把模型转换为 int8 的 TFLite 模型，因此采用了 Keras 默认的 QAT（量化感知训练，quantization-aware training）实现方式。
# Build the float inference network (`basemodel`) and a CTC training wrapper
# (`model`), optionally restore float weights, then fine-tune a
# quantization-aware (QAT) copy of the wrapper with tfmot so it can later be
# converted to an int8 TFLite model.
#
# KL, Model, densenet, ctc_lambda_func, nclass, args, tfmot, tf, chkp_dir,
# chkp_qat_dir, chkp_path, log_dir, lr_boundaries, lr_values, train_dataset
# and val_dataset are defined elsewhere in the script.

# NOTE: `input` shadows the builtin input(); the name is kept because code
# outside this section may reference it.
input = KL.Input(shape=(args.height, None, 3), name='input')
labels = KL.Input(name='labels', shape=[None], dtype='int64')
input_length = KL.Input(name='input_length', shape=[1], dtype='int64')
label_length = KL.Input(name='label_length', shape=[1], dtype='int64')

y_predict = densenet(input=input, num_classes=nclass)
basemodel = Model(inputs=input, outputs=y_predict)
basemodel.summary()

loss_out = KL.Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [y_predict, labels, input_length, label_length]
)
model = Model(inputs=[input, labels, input_length, label_length], outputs=loss_out)
model.summary()
# The Keras loss just returns the model output, i.e. the value produced by the
# 'ctc' Lambda layer is used directly as the training loss.
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])

# Optionally restore float (non-QAT) weights.
init_epoch = 0
if args.restore is not None:
    if args.restore_path is not None:
        restore_path = args.restore_path
    else:
        restore_path = chkp_dir + '/chkp-{:04}.chkp'.format(args.restore)
    print('Restore from epoch: ', args.restore, restore_path)
    if restore_path.endswith('h5'):
        # .h5 files are matched by layer name; mismatching layers are skipped.
        model.load_weights(restore_path, by_name=True, skip_mismatch=True)
    else:
        model.load_weights(restore_path)
    init_epoch = int(args.restore)

# NOTE(review): the whole CTC training graph (int64 label/length inputs and the
# 'ctc' Lambda) is quantized here, not just `basemodel`, which is the graph
# that would be exported to TFLite. The reported failure is raised when the
# "batch_normalization" layer is called on a rank-7 tensor
# (1, 1, 1, None, 16, None, 64) instead of a rank-4 NHWC tensor. Unverified
# directions to try:
#   - quantize only `basemodel` and attach the CTC Lambda to its output;
#   - exclude the standalone BatchNormalization layers via
#     tfmot.quantization.keras.quantize_annotate_layer / quantize_apply.
# TODO confirm.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)
q_aware_model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])

# The QAT epoch counter is tracked separately from float-training epochs.
# It defaults to 0 so the name is always bound when passed to fit() below.
init_epoch_qat = 0
if args.restore_qat is not None:
    if args.restore_qat_path is not None:
        restore_qat_path = args.restore_qat_path
    else:
        restore_qat_path = chkp_qat_dir + '/chkp-{:04}.chkp'.format(args.restore_qat)
    print('Restore from epoch: ', restore_qat_path)
    q_aware_model.load_weights(restore_qat_path)
    init_epoch_qat = int(args.restore_qat)

q_aware_model.summary()

print('-----------Start qat training-----------')
q_aware_model.fit(
    train_dataset,
    epochs=args.epoch_qat,
    validation_data=val_dataset,
    # Resume from the restored QAT epoch. Previously the float-training epoch
    # (init_epoch) was passed here: resuming float training past
    # args.epoch_qat made QAT run zero epochs, and init_epoch_qat was unused.
    initial_epoch=init_epoch_qat,
    callbacks=[
        tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1, update_freq=500, write_images=True),
        # NOTE(review): QAT weights are restored from chkp_qat_dir but saved to
        # chkp_path; confirm chkp_path is a QAT-specific location so the float
        # checkpoints are not overwritten.
        tf.keras.callbacks.ModelCheckpoint(filepath=chkp_path, save_weights_only=True, verbose=1),
        tf.keras.callbacks.LearningRateScheduler(
            tf.keras.optimizers.schedules.PiecewiseConstantDecay(lr_boundaries, lr_values),
            verbose=1,
        ),
    ],
)
但一运行 QAT 训练就会报以下错误。大家有没有遇到过类似的问题？针对这个报错，有什么解决办法吗？非常感谢！
File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/ops.py", line 1939, in _create_c_op
    raise ValueError(e.message)
ValueError: Exception encountered when calling layer "batch_normalization" (type BatchNormalization).

Shape must be rank 4 but is rank 7 for '{{node batch_normalization/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=1.1e-05, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization/ReadVariableOp, batch_normalization/ReadVariableOp_1, batch_normalization/FusedBatchNormV3/ReadVariableOp, batch_normalization/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,1,1,?,16,?,64], [64], [64], [64], [64].

Call arguments received:
  • inputs=tf.Tensor(shape=(1, 1, 1, None, 16, None, 64), dtype=float32)
  • training=None
由于希望将模型部署到端侧的非 GPU 设备，需要把模型转换为 int8 的 TFLite 模型，因此采用了 Keras 默认的 QAT（量化感知训练，quantization-aware training）实现方式。
但一运行 QAT 训练就会报以下错误。大家有没有遇到过类似的问题？针对这个报错，有什么解决办法吗？非常感谢！