def bucket_embedding(self, width, device):
    em = [ix for ix, v in enumerate(self.bucket_bins) if width >= v][0]
    return self.span_width_embedding(torch.LongTensor([em]).to(device))
Thanks for pointing out the bug: the wrong index was taken when looking up the width bucket.
em = [ix for ix, v in enumerate(self.bucket_bins) if width >= v][0]
=>
em = [ix for ix, v in enumerate(self.bucket_bins) if width >= v][-1]
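To make the difference concrete, here is a small sketch that compares the two indexing choices without needing torch. The helper names (`bucket_index_first`, `bucket_index_last`, `bucket_index_bisect`) are made up for illustration and are not part of model.py; the `bisect` variant is just one possible equivalent lookup, assuming the bins are sorted and `width` is non-negative.

```python
from bisect import bisect_right

# Same bins as self.bucket_bins in model.py.
BUCKET_BINS = [0, 1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32, 63, 64]

def bucket_index_first(width):
    # Original indexing: the first bin with v <= width.
    # Since BUCKET_BINS[0] == 0, this is index 0 for any width >= 0.
    return [ix for ix, v in enumerate(BUCKET_BINS) if width >= v][0]

def bucket_index_last(width):
    # Corrected indexing: the last (largest) bin with v <= width.
    return [ix for ix, v in enumerate(BUCKET_BINS) if width >= v][-1]

def bucket_index_bisect(width):
    # Hypothetical alternative: binary search for the same bin.
    # Assumes BUCKET_BINS is sorted ascending and width >= 0.
    return bisect_right(BUCKET_BINS, width) - 1
```

With `[0]`, every span gets the embedding at index 0, so the width feature carries no information. With `[-1]`, a width of 6 falls in the bin starting at 5 (index 5), a width of 7 in the bin starting at 7 (index 6), and anything 64 or longer in the final bin (index 13).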
self.bucket_bins = [0, 1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32, 63, 64]
self.span_width_embedding = nn.Embedding(len(self.bucket_bins), span_width_embedding_dim)
def bucket_embedding(self, width, device):
    em = [ix for ix, v in enumerate(self.bucket_bins) if width >= v][0]
    return self.span_width_embedding(torch.LongTensor([em]).to(device))
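The bucket_embedding function in model.py is supposed to encode the span's length feature. But isn't the value of em always 0? If so, the length feature is never actually used. I'd like to know what bucket_embedding is really meant to do. Thanks, everyone.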