motefly / DeepGBM

SIGKDD'2019: DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks

Model saving and data prediction #5

Status: Closed (liouxy closed this issue 4 years ago)

liouxy commented 5 years ago

Hello, I've noticed that the train_DEEPGBM method returns three values (return deepgbm_model, opt, metric), but the main function doesn't handle these outputs:

elif args.model == "deepgbm":
        num_data = dh.load_data(args.data+'_num')
        cate_data = dh.load_data(args.data+'_cate')
        # designed for faster cateNN
        cate_data = dh.trans_cate_data(cate_data)
        train_DEEPGBM(args, num_data, cate_data, plot_title)

Could you tell me how to save the model for inference after training? I would also like to know how to preprocess the data and how to feed it to the saved model.

Thank you very much!

motefly commented 5 years ago

The model is based on PyTorch, so please consider saving it with the PyTorch API. For the LightGBM model, see the LightGBM docs. For preprocessing, please see https://github.com/motefly/DeepGBM/blob/master/experiments/preprocess/example.sh .
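
For readers who land here with the same question, below is a minimal sketch of the usual PyTorch save/load pattern (saving the state_dict and loading it into a freshly built model). It is not taken from the DeepGBM repository: the helper names save_model and load_model, the file name deepgbm_model.pt, and the nn.Linear stand-in are all hypothetical. It also assumes that the first value returned by train_DEEPGBM is a regular torch.nn.Module, which the issue text suggests but does not confirm.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_model(model: nn.Module, path: str) -> None:
    """Save only the learned parameters (the state_dict), not the whole object."""
    torch.save(model.state_dict(), path)


def load_model(model: nn.Module, path: str) -> nn.Module:
    """Load saved parameters into a freshly constructed model of the same architecture."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # switch layers such as dropout and batch norm to inference mode
    return model


# Hypothetical stand-in for the trained network. In main.py this would be the
# first value returned by train_DEEPGBM (an assumption, see the note above).
trained = nn.Linear(4, 1)

ckpt_path = os.path.join(tempfile.mkdtemp(), "deepgbm_model.pt")
save_model(trained, ckpt_path)

# At inference time, rebuild the same architecture, then load the weights.
restored = load_model(nn.Linear(4, 1), ckpt_path)

with torch.no_grad():  # no gradients are needed for prediction
    batch = torch.randn(2, 4)  # two rows of four made-up features
    prediction = restored(batch)
```

To use this with the snippet from main.py, the call would first have to keep the return values, for example deepgbm_model, opt, metric = train_DEEPGBM(args, num_data, cate_data, plot_title), and the inference script would have to construct the model with the same arguments used during training before loading the weights.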