pmixer / SASRec.pytorch

PyTorch(1.6+) implementation of https://github.com/kang205/SASRec
Apache License 2.0
331 stars 90 forks source link

log2feats函数中有疑惑 #28

Open toyoululu opened 1 year ago

toyoululu commented 1 year ago

attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev)) seqs的维度应该是(batch_size,seq_len,embedding)其中(tl, tl)怎么能保证batch_size=seq_len? seqs = torch.transpose(seqs, 0, 1)为什么要transpose呀 期待你的答复

pmixer commented 1 year ago

@toyoululu 两件事,第一件事,没有地方要求过 batch size = seq len,tl, tl 也不对其提供保证;第二件事,做 transpose 是 torch 的 mha 层要求时间维提到最前面。疑问最终的源头可能是对多头注意力层(mha)不熟悉,建议观看 https://www.bilibili.com/video/BV1J441137V6/

toyoululu commented 1 year ago

谢谢回答,我仔细看了看api和代码,发现没有任何问题,我之前自己理解错误了