Closed: MissPenguin closed this issue 9 months ago.
Paddle versions:
paddlepaddle-gpu 2.0.0.post90
paddleslim 2.0.0
The sample_generator argument passed to quant_post_static() is used internally to build the DataLoader:
self._data_loader = io.DataLoader.from_generator(
    feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
if self._sample_generator is not None:
    self._data_loader.set_sample_generator(
        self._sample_generator,
        batch_size=self._batch_size,
        drop_last=True,
        places=self._place)
elif self._batch_generator is not None:
    self._data_loader.set_batch_generator(
        self._batch_generator, places=self._place)
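As a plain-Python illustration (not PaddleSlim's actual code), the sketch below mimics what that snippet does with a sample generator: individual samples are grouped into lists of batch_size, an incomplete final batch is discarded because drop_last=True, and iteration stops after batch_nums batches. The batch_nums cap is an assumption based on how the PaddleSlim documentation describes that parameter; `batch_samples` is a hypothetical name.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def batch_samples(
    sample_generator: Callable[[], Iterable[T]],
    batch_size: int,
    drop_last: bool = True,
    batch_nums: Optional[int] = None,
) -> Iterator[List[T]]:
    """Hypothetical helper: group samples into batches like set_sample_generator.

    `sample_generator` is a generator *function* (called with no arguments),
    matching how sample_generator(loader) returns __reader__ in quant.py.
    """
    batch: List[T] = []
    emitted = 0
    for sample in sample_generator():
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
            emitted += 1
            # Stop early once batch_nums full batches have been produced.
            if batch_nums is not None and emitted >= batch_nums:
                return
    # A trailing partial batch is only kept when drop_last is False.
    if batch and not drop_last:
        yield batch


print(list(batch_samples(lambda: iter(range(7)), batch_size=3)))  # [[0, 1, 2], [3, 4, 5]]
```

With drop_last=True, sample 6 never reaches the model, and a generator that yields fewer than batch_size samples in total produces no batch at all; with batch_nums=10, calibration sees at most batch_size × 10 samples.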
The corresponding code in quant.py can be changed to:
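import numpy as np


def sample_generator(loader):
    # Wrap the PaddleOCR DataLoader as a generator function for quant_post_static.
    def __reader__():
        for data in loader:
            # data[0] holds the image tensor of the current batch.
            images = np.array(data[0])
            yield images
    return __reader__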
import paddle
import paddleslim
from ppocr.data import build_dataloader


def main(config, device, logger, vdl_writer):
    # Build the PaddleOCR 'Train' loader before switching to static-graph mode.
    loader = build_dataloader(config, 'Train', device, logger)
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    model_dir = 'inference/ch_ppocr_mobile_v2.0_rec_infer/'
    paddleslim.quant.quant_post_static(
        executor=exe,
        model_dir=model_dir,
        model_filename='inference.pdmodel',
        params_filename='inference.pdiparams',
        quantize_model_path='quant_post_static_model',
        sample_generator=sample_generator(loader),
        # batch_size=256,
        batch_nums=10)  # use at most 10 calibration batches
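The loader returned by build_dataloader already yields batches, whereas set_sample_generator (quoted above) expects individual samples and re-batches them itself. One option is to split each batch along its first axis before handing it over. This is a hedged sketch: it assumes `data[0]` converts to a NumPy array of shape [B, C, H, W] and that the inference model has a single image input; `unbatched_sample_generator` is a hypothetical name, not part of PaddleOCR or PaddleSlim. Each sample is yielded as a one-element tuple, one field per feed variable.

```python
import numpy as np


def unbatched_sample_generator(loader):
    """Hypothetical helper: yield one (image,) sample per image from a batched loader."""
    def __reader__():
        for data in loader:
            # Assumption: data[0] converts to an array of shape [B, C, H, W].
            images = np.array(data[0])
            for image in images:
                # One field per feed variable; assumes a single image input.
                yield (image,)
    return __reader__


# A fake two-batch loader standing in for build_dataloader(...).
fake_loader = [
    [np.zeros((4, 3, 32, 100), dtype="float32")],
    [np.ones((4, 3, 32, 100), dtype="float32")],
]
samples = list(unbatched_sample_generator(fake_loader)())
print(len(samples), samples[0][0].shape)  # 8 (3, 32, 100)
```

In quant.py this would replace sample_generator(loader) in the quant_post_static call; the regrouping into batches is then left entirely to the batch_size passed to quant_post_static.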
The model comes from PaddleOCR: https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15
Run command: python tools/quant.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
Main code of quant.py:
Error message: the error changes from run to run; here is another one: