shenweichen / DeepCTR-Torch

【PyTorch】Easy-to-use,Modular and Extendible package of deep-learning based CTR models.
https://deepctr-torch.readthedocs.io/en/latest/index.html
Apache License 2.0
3.02k stars 705 forks

How do I save and load a trained model? #230

Open SoulEvill opened 2 years ago

SoulEvill commented 2 years ago

Discussed in https://github.com/shenweichen/DeepCTR-Torch/discussions/197

Originally posted by **SoulEvill** July 30, 2021

First of all, thank you very much for the deepctr-torch package; it makes it very quick to try out different models. However, when I load a trained model and call its predict function, I get an error.

**Operating environment:** python version 3.8.8, torch version 1.8.1, deepctr-torch version 0.2.7

Please refer to the Issue: Bug report template and provide the environment and steps to reproduce:

**Describe the bug**

Calling predict on a model that was loaded from a saved file raises an error. The full error message:

    NotImplementedError Traceback (most recent call last)
    in
    ----> 1 reload_model.predict(train_model_input)

    /databricks/python/lib/python3.8/site-packages/deepctr_torch/models/basemodel.py in predict(self, x, batch_size)
        340 x = x_test[0].to(self.device).float()
        341
    --> 342 y_pred = model(x).cpu().data.numpy() # .squeeze()
        343 pred_ans.append(y_pred)
        344

    /databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548 result = self._slow_forward(*input, **kwargs)
        549 else:
    --> 550 result = self.forward(*input, **kwargs)
        551 for hook in self._forward_hooks.values():
        552 hook_result = hook(self, input, result)

    /databricks/python/lib/python3.8/site-packages/deepctr_torch/models/deepfm.py in forward(self, X)
        76
        77 if self.use_dnn:
    ---> 78 dnn_input = combined_dnn_input(
        79 sparse_embedding_list, dense_value_list)
        80 dnn_output = self.dnn(dnn_input)

    /databricks/python/lib/python3.8/site-packages/deepctr_torch/inputs.py in combined_dnn_input(sparse_embedding_list, dense_value_list)
        136 return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
        137 else:
    --> 138 raise NotImplementedError
        139
        140

    NotImplementedError:

**To Reproduce**

Follow the run_classification_criteo.py example, then add the save/load steps after it. The lines below are the only additions to the example code; everything else is unchanged.

    torch.save(model, "test.h5")
    reload_model = torch.load("test.h5")
    reload_model.predict(train_model_input)  # the error is raised here

**Additional context**

I also tried the following combinations, and both give the same error:

1. deepctr-torch==0.2.6 and torch==1.5.0
2. deepctr-torch==0.2.7 and torch==1.5.0

xuChenSJTU commented 2 years ago

The same issue, how to make evaluation from a saved model?
I can see that model.predict() only accepts the input x....maybe a new predict_from_model function is needed?

zanshuxun commented 2 years ago

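@SoulEvill

1. With python 3.8.8, torch 1.8.1, and deepctr-torch 0.2.7, I could not reproduce this error; the script ran to completion normally (see the attached screenshot). Please download deepctr-torch again and rerun it.

2. The combined_dnn_input function raises NotImplementedError only when both sparse_embedding_list and dense_value_list are empty. Please check whether you changed the feature columns in run_classification_criteo.py in a way that leaves both of them empty.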
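
To make the second point concrete, here is a minimal pure-Python sketch of the branching described above. It is not the library's code: `dnn_input_branch` is a hypothetical diagnostic helper, and the order of the branches is an assumption based on the traceback, which shows only the dense-only branch and the final `raise NotImplementedError`. It only reports which inputs would be combined, so it can be run without torch.

```python
def dnn_input_branch(sparse_embedding_list, dense_value_list):
    """Report which inputs would be combined for the DNN.

    Hypothetical helper (not part of deepctr_torch) that mirrors the
    branching described in the thread: the error is raised only when
    both lists are empty.
    """
    has_sparse = len(sparse_embedding_list) > 0
    has_dense = len(dense_value_list) > 0
    if has_sparse and has_dense:
        return "sparse+dense"
    if has_sparse:
        return "sparse"
    if has_dense:
        return "dense"
    # Neither kind of feature reached the model: the case reported in this issue.
    raise NotImplementedError(
        "both sparse_embedding_list and dense_value_list are empty"
    )


# Placeholder strings stand in for the tensors the real function receives.
print(dnn_input_branch(["embedding"], []))
try:
    dnn_input_branch([], [])
except NotImplementedError as exc:
    print("would fail:", exc)
```

If a reloaded model hits the last branch, one thing worth checking is whether the feature columns used to build the model and the keys of the input passed to predict still match.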

zanshuxun commented 2 years ago

> The same issue, how to make evaluation from a saved model? I can see that model.predict() only accepts the input x....maybe a new predict_from_model function is needed?

What do you mean by a predict_from_model function? When you use model.predict(), the model should be your saved model.
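
For readers looking for a concrete save/load pattern, the sketch below shows the general PyTorch approach of saving only the weights with `state_dict()` and loading them into a freshly built model. The thread does not confirm that this resolves the reported error. `TinyModel`, `save_weights`, and `load_weights` are hypothetical names used for illustration. The assumption is that a DeepCTR-Torch model, which the traceback shows is called through `torch.nn.Module`, can be handled the same way once it is rebuilt with the same feature columns.

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a trained CTR model."""

    def __init__(self, in_features: int = 4):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.linear(x))


def save_weights(model: nn.Module, target) -> None:
    # Save only the parameters, not the pickled model object.
    torch.save(model.state_dict(), target)


def load_weights(model: nn.Module, source) -> nn.Module:
    # The model must be constructed with the same architecture first.
    model.load_state_dict(torch.load(source, map_location="cpu"))
    model.eval()  # switch to inference mode before predicting
    return model


torch.manual_seed(0)
trained = TinyModel()

buffer = io.BytesIO()  # a file path such as "weights.pth" works as well
save_weights(trained, buffer)
buffer.seek(0)

restored = load_weights(TinyModel(), buffer)

x = torch.randn(8, 4)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)
```

With this pattern the restored object is an ordinary model instance, so the usual prediction call is made on it directly, in line with the reply above.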