fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0
3.06k stars 450 forks source link

训练完的model ,能保存吗,下次要使用时,再从文件加载 #302

Closed songsh closed 4 years ago

songsh commented 4 years ago

如题,不需要每次要用model ,都重新训练.

wuyue92tree commented 4 years ago

如题,不需要每次要用model ,都重新训练.

可以直接在Trainer中传入SAVE_PATH进行模型保存

SAVE_PATH = './data/model'
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
                        dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device, save_path=SAVE_PATH, callbacks=[VisdomCallBack(plotter_loss)])

或者

模型保存

from fastNLP.io.model_io import ModelSaver

模型加载

from fastNLP.io.model_io import ModelLoader

yhcc commented 4 years ago

Trainer中的保存是直接使用了torch.save(model),这种方式有一些缺点,例如pytorch版本更迭的时候可能报错;机器之间迁移可能会遇到问题。一个更加好的保存方式是torch.save(model.state_dict()), 这种方式只保留model的参数,其优势是兼容性强,缺点是以后每次load之前,得先用一模一样的参数把model先初始化出来,然后使用model.load_state_dict(torch.load('xxxx_file'))这样加载。目前我们正在修改部分代码以更好支持第二种方式以方便大家实际将model用于部署环境。