shenweichen / DeepCTR-Torch

【PyTorch】Easy-to-use, Modular and Extendible package of deep-learning based CTR models.
https://deepctr-torch.readthedocs.io/en/latest/index.html
Apache License 2.0
2.96k stars, 698 forks

What is X in the forward function of DeepFM? #234

Closed: Jeriousman closed this issue 2 years ago

Jeriousman commented 2 years ago

Describe the question (problem description)

The DeepFM model has a forward function that takes X as its input (this applies not only to DeepFM but to the other models as well). But what does X look like? Since it is the input of forward, I wanted to check the shape of X and what it contains, but DeepCTR uses Keras fit and predict, so the code doesn't really show what you pass as input to the DeepFM model. Can you tell me what the input is and pseudo-generate some input data? And if I'm missing something, where can I check X?

zanshuxun commented 2 years ago

This is the DeepCTR-Torch repository, and there is no Keras fit/predict here. Do you mean the DeepCTR repository?

Jeriousman commented 2 years ago

No, it is clearly DeepCTR-Torch. It might not be the Keras one, but predict and fit are there. If you go to class BaseModel in models/basemodel.py, you can find

```python
def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoch=0, validation_split=0.,
        validation_data=None, shuffle=True, callbacks=None):
```

and

```python
def predict(self, x, batch_size=256):
```

But these take a lowercase `x`, not an uppercase `X`. In `class Linear(nn.Module)` in models/basemodel.py, the uppercase X appears in `def forward(self, X, sparse_feat_refine_weight=None):`, and I believe it is the same input as the X in `def forward(self, X):` in models/deepfm.py.

[Screen capture 2022-03-05 142933]

What I want to know is what the uppercase X should look like. I am stuck here at the moment. Could you please explain what X is: its shape and what it contains, ideally with an example using pseudo-generated data in torch?
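To make the uppercase X more concrete: as far as we can tell from the source, every module that consumes it (Linear, and the FM and DNN parts of DeepFM) receives the same 2-D batch and cuts out its own columns using a per-feature column index (the model's `feature_index`, built from the feature columns). The sketch below imitates that slicing with NumPy only. The feature names, the `FEATURE_INDEX` dict and the helper `split_features` are hypothetical stand-ins, not the library's API:

```python
from collections import OrderedDict

import numpy as np

# Hypothetical column index (an assumption for illustration):
# feature name -> (start, end) column range inside X.
# Each scalar sparse or dense feature occupies exactly one column here.
FEATURE_INDEX = OrderedDict([
    ("user_id", (0, 1)),
    ("item_id", (1, 2)),
    ("price", (2, 3)),
])
SPARSE = {"user_id", "item_id"}


def split_features(X):
    """Slice a (batch_size, n_columns) batch back into per-feature arrays."""
    sparse, dense = {}, {}
    for name, (start, end) in FEATURE_INDEX.items():
        block = X[:, start:end]  # keeps shape (batch_size, end - start)
        if name in SPARSE:
            # IDs were stored as floats; cast back to integers,
            # analogous to the .long() cast before an embedding lookup.
            sparse[name] = block.astype(np.int64)
        else:
            dense[name] = block
    return sparse, dense


# A tiny hand-written batch: two rows, three columns (user_id, item_id, price).
X = np.array([[3.0, 7.0, 0.25],
              [1.0, 2.0, 0.80]], dtype=np.float32)
sparse, dense = split_features(X)
print(sparse["item_id"].ravel())  # [7 2]
print(dense["price"].shape)       # (2, 1)
```

The point of the design is that one flat float tensor can carry every feature; each sub-module only needs to know which column range belongs to which feature.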

zanshuxun commented 2 years ago

It's not Keras fit. It's DeepCTR-Torch's own fit, which is defined here: https://github.com/shenweichen/DeepCTR-Torch/blob/master/deepctr_torch/models/basemodel.py#L136

The input X can be found here: https://github.com/shenweichen/DeepCTR-Torch/blob/master/deepctr_torch/models/basemodel.py#L244; it is derived from the training input.
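A minimal sketch of that derivation, assuming a DeepFM setup with two sparse features and one dense feature (the feature names, vocabulary sizes and random data are made up). As we read `BaseModel.fit`, it orders the input dict by the model's feature columns, turns each 1-D array into a single column, and concatenates everything along the last axis; each mini-batch of rows from that array, cast to float, is then the X that `forward` receives. The code mirrors those steps with NumPy only, so treat it as an approximation rather than the library's own implementation:

```python
import numpy as np

# Hypothetical feature layout (an assumption): (name, kind, vocabulary size).
FEATURES = [
    ("user_id", "sparse", 100),
    ("item_id", "sparse", 50),
    ("price", "dense", None),
]


def make_fake_input(n_samples, seed=0):
    """Pseudo-generate a dict of 1-D arrays, like the lowercase x passed to fit()."""
    rng = np.random.default_rng(seed)
    data = {}
    for name, kind, vocab in FEATURES:
        if kind == "sparse":
            # Sparse features: label-encoded integer IDs in [0, vocab).
            data[name] = rng.integers(0, vocab, size=n_samples)
        else:
            # Dense features: plain real values.
            data[name] = rng.random(n_samples)
    return data


def to_model_input(data):
    """Stack per-feature arrays into one 2-D float32 array, one column per feature."""
    columns = [data[name] for name, _, _ in FEATURES]  # fixed feature order
    # 1-D arrays become (n_samples, 1) columns before concatenation.
    columns = [c[:, None] if c.ndim == 1 else c for c in columns]
    return np.concatenate(columns, axis=-1).astype(np.float32)


x_dict = make_fake_input(n_samples=8)
X = to_model_input(x_dict)
print(X.shape)  # (8, 3)
print(X.dtype)  # float32
```

Under the same assumptions, one batch as `forward` would see it is roughly `torch.from_numpy(X[:batch_size])`, a float tensor of shape `(batch_size, n_columns)`. If we read the feature-column code correctly, a variable-length sparse feature would take `maxlen` columns and a dense feature of dimension d would take d columns, so `n_columns` is not always equal to the number of features.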

Jeriousman commented 2 years ago

I will look into it and let you know! Thank you.