Closed songsh closed 4 years ago
如题,不需要每次要用model ,都重新训练.
可以直接在Trainer中传入SAVE_PATH进行模型保存
SAVE_PATH = './data/model'
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device, save_path=SAVE_PATH, callbacks=[VisdomCallBack(plotter_loss)])
或者
from fastNLP.io.model_io import ModelSaver
from fastNLP.io.model_io import ModelLoader
Trainer中的保存是直接使用了torch.save(model),这种方式有一些缺点,例如pytorch版本更迭的时候可能报错;机器之间迁移可能会遇到问题。一个更加好的保存方式是torch.save(model.state_dict()), 这种方式只保留model的参数,其优势是兼容性强,缺点是以后每次load之前,得先用一模一样的参数把model先初始化出来,然后使用model.load_state_dict(torch.load('xxxx_file'))这样加载。目前我们正在修改部分代码以更好支持第二种方式以方便大家实际将model用于部署环境。
如题,不需要每次要用model ,都重新训练.