Open DamonsJ opened 8 months ago
The reason for this request: right now the attention masks handled in fuse_multiheadattention.cpp are all attributes, but logically the attention mask should be an input, because the mask may be different for every inference.
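As a minimal illustration of why the mask has to stay dynamic (this is my own sketch, not code from the repo): with nn.MultiheadAttention the mask is a call-time argument, so a fused operator whose mask is baked in as a constant attribute cannot reproduce a mask that changes between inferences.

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.rand(2, 128, 768)

for step in range(2):
    # a different additive mask on every call, e.g. from a changing padding pattern
    mask = torch.zeros(128, 128)
    mask[:, 64 + step * 32:] = float("-inf")
    y, _ = attn(x, x, x, attn_mask=mask)
```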
torch.unbind op_3 1 3 4 5 6 7 dim=0
Is output 5 here left unused?
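For reference, each PNNX text-IR operator line is laid out as `[type] [name] [input count] [output count] [input operands] [output operands] [params]` (see the pnnx documentation). A small throwaway sketch of how that quoted line decomposes (the splitting code is my own illustration):

```python
# Decompose the quoted operator line following the documented field order
# "[type] [name] [#inputs] [#outputs] [input operands] [output operands] [params]".
line = "torch.unbind op_3 1 3 4 5 6 7 dim=0"
fields = line.split()
op_type, op_name = fields[0], fields[1]
num_in, num_out = int(fields[2]), int(fields[3])
inputs = fields[4:4 + num_in]                       # ['4']
outputs = fields[4 + num_in:4 + num_in + num_out]   # ['5', '6', '7'] -> the q, k, v tensors
params = fields[4 + num_in + num_out:]              # ['dim=0']
print(op_type, op_name, inputs, outputs, params)
```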
sorry
I made the following changes:
class fuse_multiheadattention_pass_20 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 16
pnnx.Input input_0 0 1 0
pnnx.Input attn_mask 0 1 1
nn.Linear op_0 1 1 0 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute op_2 1 1 3 4 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 4 5 6 7 dim=0
torch.permute op_4 1 1 6 8 dims=(0,1,3,2)
torch.matmul op_5 2 1 5 8 9
pnnx.Expression op_6 2 1 9 1 10 expr=add(mul(@0,%inv_sqrt_embed_dim_per_head),@1)
F.softmax op_7 1 1 10 11 dim=-1
torch.matmul op_8 2 1 11 7 12
torch.permute op_9 1 1 12 13 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 14 15 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 15
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Input attn_mask 0 1 attn_mask
nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f;

        printf(" <=====> inv_sqrt_embed_dim_per_head %.3f\n", inv_sqrt_embed_dim_per_head);
        printf(" <=====> embed_dim %3d qkv_out_features %3d num_heads %3d feat_per_head %3d\n", embed_dim, qkv_out_features, num_heads, feat_per_head);

        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        Operator* op = ops.at("attention");

        const int embed_dim = captured_params.at("embed_dim").i;
        const bool qkvbias = captured_params.at("qkvbias").b;
        const bool outbias = captured_params.at("outbias").b;
        const bool bias = qkvbias || outbias;

        op->params["bias"] = bias;

        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight");

        if (bias)
        {
            if (qkvbias)
            {
                op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["in_proj_bias"] = Attribute();
                op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
                op->attrs["in_proj_bias"].shape = {embed_dim * 3};
                op->attrs["in_proj_bias"].set_float32_data(std::vector<float>(embed_dim * 3, 0.f));
            }
        }

        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");

        if (bias)
        {
            if (outbias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
                op->attrs["out_proj.bias"].shape = {embed_dim};
                op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
            }
        }

        Operator* op_attr = ops.at("attn_mask");

        const int batch = captured_params.at("batch").i;
        const int size = captured_params.at("size").i;
        printf(" <=====> batch %2d size : %2d \n", batch, size);

        // hack attn_mask shape
        op_attr->attrs["data"].shape = {batch, 1, size, size};

        // hack attn_mask value
        std::vector<char>& data = op_attr->attrs["data"].data;
        size_t len = data.size();
        data.resize(len * batch);
        for (int i = 1; i < batch; i++)
        {
            memcpy(&data[len * i], &data[0], len);
        }
    }
};
The model is exported from Python like this:
class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale + atten_mask
        attn = F.softmax(attn, dim=-1)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.attention_0_0 = TSelfAttention(embed_dim=768, num_heads=12)

    def forward(self, x, attention_mask):
        a = self.attention_0_0(x, attention_mask)
        return a

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2, 128, 768)
    attention_mask = torch.rand(2, 1, 128, 128)
    r = net(x, attention_mask)

    # export torchscript
    mod = torch.jit.trace(net, (x, attention_mask))
    mod.save("test_bert_fused.pt")

    # torchscript to pnnx
    import os
    os.system("../build/src/pnnx test_bert_fused.pt inputshape=[2,128,768],[2,1,128,128]")
That does not work either.
How should I go about debugging this so that this kind of operator gets fused correctly?
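For reference, this is how I check whether the pass actually fired (my own sketch; the file names assume pnnx's usual output naming next to the input model, so adjust them if yours differ). pnnx also emits a *_pnnx.py reconstruction whose output can be compared against the original Model for a numerical check.

```python
# Assumed output name: pnnx normally writes test_bert_fused.pnnx.param
# next to the input .pt file.
with open("test_bert_fused.pnnx.param") as f:
    graph = f.read()

# If the pass matched, a single nn.MultiheadAttention line replaces the
# Linear/reshape/permute/matmul/softmax chain.
print("fused:", "nn.MultiheadAttention" in graph)
print("leftover softmax:", "F.softmax" in graph)
```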
After some debugging I found that this version works:
class fuse_multiheadattention_pass_20 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 16
pnnx.Input input_0 0 1 0
pnnx.Input attn_mask 0 1 1
nn.Linear op_0 1 1 0 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute op_2 1 1 3 4 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 4 5 6 7 dim=0
torch.permute op_4 1 1 6 8 dims=(0,1,3,2)
torch.matmul op_5 2 1 5 8 9
pnnx.Expression op_6 2 1 9 1 10 expr=add(mul(@0,%inv_sqrt_embed_dim_per_head),@1)
F.softmax op_7 1 1 10 11 dim=-1
torch.matmul op_8 2 1 11 7 12
torch.permute op_9 1 1 12 13 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 14 15 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 15
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 0
pnnx.Input attn_mask 0 1 1
nn.MultiheadAttention attention 2 1 0 1 15 embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
pnnx.Output output 1 0 15
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f;

        printf(" <=====> inv_sqrt_embed_dim_per_head %.3f\n", inv_sqrt_embed_dim_per_head);
        printf(" <=====> embed_dim %3d qkv_out_features %3d num_heads %3d feat_per_head %3d\n", embed_dim, qkv_out_features, num_heads, feat_per_head);

        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        printf(" ====> fuse_multiheadattention_pass_20 write \n");

        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        for (auto& item : ops)
        {
            printf(" ====> fuse_multiheadattention_pass_20 op : %s \n", item.first.c_str());
        }
        for (auto& item : captured_params)
        {
            printf(" ====> fuse_multiheadattention_pass_20 param : %s \n", item.first.c_str());
        }

        Operator* op = ops.at("attention");
        printf(" ====> fuse_multiheadattention_pass_20 write op : %p \n", op);

        const int embed_dim = captured_params.at("embed_dim").i;
        const bool qkvbias = captured_params.at("qkvbias").b;
        const bool outbias = captured_params.at("outbias").b;
        const bool bias = qkvbias || outbias;

        op->params["bias"] = bias;

        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight");

        if (bias)
        {
            if (qkvbias)
            {
                op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["in_proj_bias"] = Attribute();
                op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
                op->attrs["in_proj_bias"].shape = {embed_dim * 3};
                op->attrs["in_proj_bias"].set_float32_data(std::vector<float>(embed_dim * 3, 0.f));
            }
        }

        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");

        if (bias)
        {
            if (outbias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
                op->attrs["out_proj.bias"].shape = {embed_dim};
                op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
            }
        }

        printf(" ====> fuse_multiheadattention_pass_20 write o\n");

        const int batch = captured_params.at("batch").i;
        const int size = captured_params.at("size").i;
        printf(" <=====> batch %2d size : %2d \n", batch, size);
    }
};
class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale + atten_mask
        attn = F.softmax(attn, dim=-1)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x
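To convince myself that the parameter mapping in write() is valid (the q/k/v packing order of the fused Linear, the 1/sqrt(head_dim) scale, and the mask handling), here is a small standalone check; it is my own sketch, not part of the patch. Note that nn.MultiheadAttention only accepts a 2-D (L, S) or 3-D (batch*num_heads, L, S) attn_mask, so the (batch, 1, L, L) mask used by TSelfAttention has to be expanded first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, B, L = 768, 12, 2, 128

src = TSelfAttention(embed_dim, num_heads).eval()  # the class defined above
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

# qkv.weight is laid out as [q; k; v] along dim 0, which is exactly the
# in_proj_weight layout nn.MultiheadAttention expects.
with torch.no_grad():
    mha.in_proj_weight.copy_(src.qkv.weight)
    mha.in_proj_bias.copy_(src.qkv.bias)
    mha.out_proj.weight.copy_(src.proj.weight)
    mha.out_proj.bias.copy_(src.proj.bias)

x = torch.rand(B, L, embed_dim)
mask = torch.rand(B, 1, L, L)

ref = src(x, mask)

# expand the (B, 1, L, L) additive mask to the (B*num_heads, L, L) form
mask3d = mask.expand(B, num_heads, L, L).reshape(B * num_heads, L, L)
out, _ = mha(x, x, x, attn_mask=mask3d, need_weights=False)

print((ref - out).abs().max())  # should be floating-point noise, ~1e-6
```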
Thanks a lot! I wonder whether this could be supported upstream later, or whether I could open an MR for it.
I added a new attention fusion rule in tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp:
There are no other code changes. After recompiling and running the test case:
it reports an error.
Does that look like it happens because the pattern was not fused?