chengkai-liu / Mamba4Rec

[RelKD'24] Mamba4Rec: Towards Efficient Sequential Recommendation with Selective State Space Models
https://arxiv.org/abs/2403.03900
MIT License

About the batch_size of item_emb #6

Closed TBI805 closed 4 months ago

TBI805 commented 5 months ago

Hello! Thanks for your contributions!

In mamba4rec.py, I would like to know why the batch_size of [item_emb] at line 63 is 1 when I debug at this point.

Actually, the [train_batch_size] in [config.yaml] is 2048, so I'm confused about this.

Would you please explain this matter to me in detail?

Thank you very much!

(Two debugger screenshots attached.)
chengkai-liu commented 5 months ago

Because the get_flops function calls the predict function. When you set a breakpoint in the forward function, the first stop happens when get_flops (line 45 of run.py) is entered; at that point, the model has not yet entered the training phase. If you continue running, training starts (line 52 of run.py), and the next time the forward breakpoint is hit, the batch size will be 2048.
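
For intuition, the FLOPs-estimation pass looks roughly like this (a minimal sketch, not RecBole's actual get_flops code; the forward(item_seq, item_seq_len) signature follows mamba4rec.py, and max_seq_len is an illustrative parameter):

```python
import torch

def flops_pass_sketch(model, max_seq_len):
    # Illustrative only: a FLOPs counter builds ONE dummy sample and runs a
    # single forward pass through the model before training starts, so the
    # batch dimension it passes in is 1.
    dummy_item_seq = torch.zeros(1, max_seq_len, dtype=torch.long)  # [1, L] padded item ids
    dummy_seq_len = torch.ones(1, dtype=torch.long)                 # [1]
    model.eval()
    with torch.no_grad():
        # A breakpoint inside model.forward fires here first, with batch size 1.
        _ = model.forward(dummy_item_seq, dummy_seq_len)
```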

TBI805 commented 5 months ago

OK, thanks for your response!

But where does the [batch_size] dimension (i.e., 1) of the input data [item_emb] come from? Why is it 1 rather than 2 or 3? How do you get this dimension during model loading and initialization?

chengkai-liu commented 5 months ago

You can set a breakpoint in get_flops, or look at the specific implementation of get_flops. You can also simply delete the get_flops function; this will not affect the results.
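
If you just want to confirm which call is which, one quick check (illustrative; the surrounding code in mamba4rec.py may differ slightly) is to print the shape where item_emb is created:

```python
# Temporary debug print next to the item_emb line in mamba4rec.py (illustrative):
item_emb = self.item_embedding(item_seq)  # item_seq: [B, L] ids -> item_emb: [B, L, D]
print("forward called, batch size =", item_emb.size(0))
# First print (triggered by get_flops via predict): batch size = 1
# Subsequent prints (training loop):               batch size = train_batch_size, e.g. 2048
```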

TBI805 commented 5 months ago

I know that when Mamba processes the input, the input size must be [B, L, D]. I understand where the L and D dimensions come from, but I'm confused about where the B dimension comes from before the forward function.

chengkai-liu commented 5 months ago

For the data loading and batching process, you need to read the implementation details of RecBole.
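
As a toy illustration of where B comes from (not RecBole's internals; the sizes are example values): the training dataloader stacks train_batch_size user sequences into a [B, L] tensor of item ids, and the embedding layer only adds the D dimension:

```python
import torch
import torch.nn as nn

train_batch_size = 2048           # train_batch_size in config.yaml
max_seq_len = 50                  # MAX_ITEM_LIST_LENGTH (illustrative value)
n_items, hidden_size = 10000, 64  # illustrative values

item_embedding = nn.Embedding(n_items, hidden_size, padding_idx=0)

# What the train dataloader hands to the model each step, conceptually:
item_seq = torch.randint(1, n_items, (train_batch_size, max_seq_len))  # [B, L] item ids
item_emb = item_embedding(item_seq)                                    # [B, L, D]
print(item_emb.shape)  # torch.Size([2048, 50, 64])
```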

TBI805 commented 5 months ago

OK, I will check the implementation details, thanks!