[Closed] Issue opened by ocelot43 — closed 4 years ago
modules文件夹下transformer.py中用于实现qkv矩阵的线性层,bias参数应该是False吧
class MultiHeadAttn(nn.Module): def __init__(self, d_model, n_head, dropout=0.1, scale=False): super().__init__() assert d_model%n_head==0 self.n_head = n_head self.qkv_linear = nn.Linear(d_model, 3*d_model) self.fc = nn.Linear(d_model, d_model)
在relative_transformer.py文件中qv的实现,bias参数是False的
class RelativeMultiHeadAttn(nn.Module): def __init__(self, d_model, n_head, dropout, r_w_bias=None, r_r_bias=None, scale=False): super().__init__() self.qv_linear = nn.Linear(d_model, d_model * 2, bias=False) self.n_head = n_head self.head_dim = d_model // n_head self.dropout_layer = nn.Dropout(dropout)
嗯,你说得对。谢谢,我修改一下~
modules文件夹下transformer.py中用于实现qkv矩阵的线性层,bias参数应该是False吧
在relative_transformer.py文件中qv的实现,bias参数是False的