PaddlePaddle / PaddleSlim

PaddleSlim is an open-source library for deep model compression and architecture search.
https://paddleslim.readthedocs.io/zh_CN/latest/
Apache License 2.0

Offline (post-training) quantization of the CRNN text recognition model raises an error #719

Closed MissPenguin closed 9 months ago

MissPenguin commented 3 years ago

The model comes from PaddleOCR: https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15

Command: python tools/quant.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml

Main code of quant.py:

def main(config, device, logger, vdl_writer):
    # build dataloader
    loader = build_dataloader(config, 'Train', device, logger)

    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    # model_dir = 'inference/ch_ppocr_mobile_v2.0_det_infer/'
    model_dir = 'inference/ch_ppocr_mobile_v2.0_rec_infer/'
    paddleslim.quant.quant_post_static(
            executor=exe,
            model_dir=model_dir,
            model_filename='inference.pdmodel',
            params_filename='inference.pdiparams',
            quantize_model_path='quant_post_static_model',
            sample_generator=loader,
            batch_size=256,
            batch_nums=10)

if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    main(config, device, logger, vdl_writer)

Error output: (image)

The error message changes from run to run, so here is another one: (image)

MissPenguin commented 3 years ago

Paddle version: paddlepaddle-gpu 2.0.0.post90
paddleslim 2.0.0

XGZhang11 commented 3 years ago

The sample_generator argument passed to quant_post_static() is used to build the DataLoader:

        self._data_loader = io.DataLoader.from_generator(
            feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
        if self._sample_generator is not None:
            self._data_loader.set_sample_generator(
                self._sample_generator,
                batch_size=self._batch_size,
                drop_last=True,
                places=self._place)
        elif self._batch_generator is not None:
            self._data_loader.set_batch_generator(
                self._batch_generator, places=self._place)
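
As a rough illustration of the sample_generator branch above, here is a pure-Python sketch of grouping samples from a callable into batches with drop_last=True. It is not Paddle's implementation; batch_samples and toy_samples are hypothetical names used only for this example. The point is that the argument gets called to obtain a generator, so it has to be a callable that yields one sample at a time.

```python
def batch_samples(sample_generator, batch_size, drop_last=True):
    """Group samples from a generator factory into lists of batch_size.

    Illustration only; this is an assumption about the general pattern,
    not the code Paddle runs internally.
    """
    batch = []
    for sample in sample_generator():  # the argument is called here
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # An incomplete trailing batch is kept only when drop_last is False.
    if batch and not drop_last:
        yield batch


def toy_samples():
    # Hypothetical sample source: yields one sample at a time.
    for i in range(7):
        yield i


batches = list(batch_samples(toy_samples, batch_size=3))
print(batches)
```

With 7 samples and a batch size of 3, the last sample does not fill a batch and is discarded under drop_last=True.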

You can change the corresponding code in quant.py to:

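import numpy as np

def sample_generator(loader):
    # Wrap the DataLoader in a callable that returns a generator,
    # instead of passing the DataLoader object itself.
    def __reader__():
        for indx, data in enumerate(loader):
            # Only the first field of each item (the images) is yielded.
            images = np.array(data[0])
            yield images
    return __reader__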

def main(config, device, logger, vdl_writer):
    loader = build_dataloader(config, 'Train', device, logger)
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    model_dir = 'inference/ch_ppocr_mobile_v2.0_rec_infer/'
    paddleslim.quant.quant_post_static(
            executor=exe,
            model_dir=model_dir,
            model_filename='inference.pdmodel',
            params_filename='inference.pdiparams',
            quantize_model_path='quant_post_static_model',
            sample_generator=sample_generator(loader),
            #batch_size=256,
            batch_nums=10)
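
The wrapper pattern can be tried without Paddle installed. The sketch below is only an illustration: toy_loader is a hypothetical stand-in for the PaddleOCR DataLoader, assumed here to yield (images, labels) items, and strings are used in place of image arrays.

```python
def sample_generator(loader):
    # Return a callable that produces a fresh generator on each call.
    def __reader__():
        for data in loader:
            yield data[0]  # keep only the first field (the images)
    return __reader__


# Hypothetical stand-in for the DataLoader: a list of (images, labels) items.
toy_loader = [
    (["img_a", "img_b"], ["label_a", "label_b"]),
    (["img_c", "img_d"], ["label_c", "label_d"]),
]

reader = sample_generator(toy_loader)
images = list(reader())
print(images)
```

Because reader is a function rather than an iterator, it can be called again to start a new pass over the data, as long as the underlying loader can itself be iterated more than once.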