jsksxs360 / How-to-use-Transformers

A quick-start tutorial for the Transformers library
https://transformers.run/
Apache License 2.0

Question about the implementation in Chapter 3: The Attention Mechanism #12

Open Chaochao2020 opened 12 months ago

Chaochao2020 commented 12 months ago

In the code below, I think it should be explained why the Q, K, V vector sequences are set equal to inputs_embeds. My understanding is that in the attention mechanism, Q, K, and V are obtained by multiplying the embeddings with three weight matrices W_Q, W_K, and W_V, which are themselves learned parameters, whereas the code below seems to silently assume that these three matrices are identity matrices.

```python
import torch
from math import sqrt

Q = K = V = inputs_embeds
dim_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / sqrt(dim_k)
print(scores.size())
```
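For comparison, what I expected is something along the following lines, with learned projection matrices (just a rough sketch; embed_dim, the dummy inputs_embeds, and the layer names are placeholders I made up, not code from the tutorial):

```python
import torch
from torch import nn
from math import sqrt

embed_dim = 768  # placeholder hidden size
W_Q = nn.Linear(embed_dim, embed_dim, bias=False)  # learned query projection
W_K = nn.Linear(embed_dim, embed_dim, bias=False)  # learned key projection
W_V = nn.Linear(embed_dim, embed_dim, bias=False)  # learned value projection

inputs_embeds = torch.randn(1, 10, embed_dim)  # dummy token embeddings
Q, K, V = W_Q(inputs_embeds), W_K(inputs_embeds), W_V(inputs_embeds)

dim_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / sqrt(dim_k)
print(scores.size())  # torch.Size([1, 10, 10])
```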

Also, dim_k = K.size(-1) here is inconsistent with the wrapped function below: above it is dim_k = K.size(-1), while in the function it is dim_k = query.size(-1).

```python
import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if query_mask is not None and key_mask is not None:
        mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)
```
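For reference, a minimal call of the function above in the self-attention case (the tensors and the padding mask are toy values I made up, and this snippet continues from the imports and definition above):

```python
# Toy self-attention call: the last two key positions are treated as padding.
query = key = value = torch.randn(1, 5, 768)
query_mask = torch.ones(1, 5)
key_mask = torch.tensor([[1., 1., 1., 0., 0.]])  # mask out the last two keys
out = scaled_dot_product_attention(query, key, value,
                                   query_mask=query_mask, key_mask=key_mask)
print(out.size())  # torch.Size([1, 5, 768])
```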

Melmaphother commented 6 months ago

> In the code below, I think it should be explained why the Q, K, V vector sequences are set equal to inputs_embeds ... Also, dim_k = K.size(-1) is inconsistent with dim_k = query.size(-1) in the wrapped function below.

  1. Because it is self-attention, Q = K = V.
  2. If K.size(-1) != query.size(-1), how could the matrices be multiplied at all? (See the quick check below.)
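A quick check that the two expressions for dim_k always agree whenever the bmm is valid (the shapes below are just an illustration):

```python
import torch

batch, seq_len, hidden = 2, 5, 768  # illustrative shapes only
query = torch.randn(batch, seq_len, hidden)
key = torch.randn(batch, seq_len, hidden)

# torch.bmm(query, key.transpose(1, 2)) only works when the last dimension of
# query matches the last dimension of key, so K.size(-1) == query.size(-1).
assert query.size(-1) == key.size(-1)
scores = torch.bmm(query, key.transpose(1, 2))
print(scores.shape)  # torch.Size([2, 5, 5])
```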
Chaochao2020 commented 6 months ago

Thanks