Adding a new Attention fusion rule in PNNX fails #5379

Open · DamonsJ opened this issue 8 months ago

DamonsJ commented 8 months ago

I added a new Attention fusion pass in tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp:

class fuse_multiheadattention_pass_19 : public fuse_multiheadattention_pass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 14
pnnx.Input              input_0     0 1 input
pnnx.Input              input_1     0 1 attn_mask
nn.Linear               op_0        1 1 input 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape          op_1        1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute           op_2        1 1 3 4 dims=(2,0,3,1,4)
torch.unbind            op_3        1 3 4 5 6 7 dim=0
torch.permute           op_4        1 1 6 8 dims=(0,1,3,2)
torch.matmul            op_5        2 1 4 8 9
pnnx.Expression         op_6        2 1 9 attn_mask 10 expr=add(div(@0,%sqrt_feat_per_head),@1)
F.softmax               op_7        1 1 10 11 dim=-1
torch.matmul            op_8        2 1 11 7 12
torch.permute           op_9        1 1 12 13 dims=(0,2,1,3)
Tensor.reshape          op_10       1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear               out_proj    1 1 14 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output             output      1 0 out
)PNNXIR";
    }
    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input              input_0     0 1 input
pnnx.Input              input_1     0 1 attn_mask
nn.MultiheadAttention   attention   2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False
pnnx.Output             output      1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f;

        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        fuse_multiheadattention_pass::write(ops, captured_params, captured_attrs);

        const int size = captured_params.at("size").i;
        const int head = captured_params.at("num_heads").i;

        Operator* op_attr = ops.at("attn_mask");

        fprintf(stderr, "op_attr->attrs[data] type %d\n", op_attr->attrs["data"].type);

        // hack attn_mask shape
        op_attr->attrs["data"].shape = {1, head, size, size};

        // hack attn_mask value
        std::vector<char>& data = op_attr->attrs["data"].data;
        size_t len = data.size();
        data.resize(len * size);
        for (int i = 1; i < size; i++)
        {
            memcpy(&data[len * i], &data[0], len);
        }
    }
};

and registered it in the fuse_multiheadattention function:
fuse_multiheadattention_pass_19 r;
pnnx_graph_rewrite(graph, &r, opindex);
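
For reference, the registration lives inside the fuse_multiheadattention() function next to the existing pass objects, roughly like this (a sketch only; the existing passes and the exact placement of opindex are abbreviated, not quoted from the file):

void fuse_multiheadattention(Graph& graph)
{
    // ... the existing fuse_multiheadattention_pass objects and their
    // pnnx_graph_rewrite(...) calls are unchanged and omitted here ...

    int opindex = 0;

    // the new pattern from this issue is tried against the graph in the
    // same way as the built-in ones
    fuse_multiheadattention_pass_19 r;
    pnnx_graph_rewrite(graph, &r, opindex);
}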

There are no other code changes. After rebuilding, I ran this test case:

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

from einops import rearrange
from typing import Any, Optional, Tuple, Union
import math

class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale

        attn = attn + atten_mask
        attn = F.softmax(attn, dim=-1)

        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.attention_0_0 = TSelfAttention(embed_dim=768, num_heads=12)

    def forward(self, x, attention_mask):
        a = self.attention_0_0(x,attention_mask)
        return a

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2, 128, 768)
    attention_mask = torch.rand(2,12,128, 128)
    r = net(x, attention_mask)

    # export torchscript
    mod = torch.jit.trace(net, (x, attention_mask))
    mod.save("test_bert_fused.pt")

    # torchscript to pnnx
    import os
    os.system("../build/src/pnnx test_bert_fused.pt inputshape=[2,128,768],[2,12,128,128]")

    return True

if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)

It fails with the following output:

python3 test_bert_fused.py > test.log
pnnxparam = test_bert_fused.pnnx.param
pnnxbin = test_bert_fused.pnnx.bin
pnnxpy = test_bert_fused_pnnx.py
pnnxonnx = test_bert_fused.pnnx.onnx
ncnnparam = test_bert_fused.ncnn.param
ncnnbin = test_bert_fused.ncnn.bin
ncnnpy = test_bert_fused_ncnn.py
fp16 = 1
optlevel = 2
device = cpu
inputshape = [2,128,768]f32,[2,12,128,128]f32
inputshape2 = 
customop = 
moduleop = 
############# pass_level0
inline module = TSelfAttention
inline module = TSelfAttention

----------------

############# pass_level1
############# pass_level2
############# pass_level3
############# pass_level4
############# pass_level5
pnnx build without onnx-zero support, skip saving onnx
############# pass_ncnn
fallback batch axis 233 for operand 0
fallback batch axis 233 for operand 1
fallback batch axis 233 for operand 2
fallback batch axis 233 for operand 3
fallback batch axis 233 for operand 4
fallback batch axis 233 for operand 5
fallback batch axis 233 for operand 6
fallback batch axis 233 for operand 7
fallback batch axis 233 for operand 8
fallback batch axis 233 for operand 9
fallback batch axis 233 for operand 11
fallback batch axis 233 for operand 12
fallback batch axis 233 for operand 13
fallback batch axis 233 for operand 14
fallback batch axis 233 for operand 15
fallback batch axis 233 for operand pnnx_expr_8_mul(9,1.250000e-01)
fallback batch axis 233 for operand pnnx_expr_8_add(mul(9,1.250000e-01),1)
permute 5-rank tensor is not supported yet!

Does this look like it is because the fusion did not happen?

DamonsJ commented 8 months ago

The reason I am doing this is that, as far as I can see, the attention masks in the current fuse_multiheadattention.cpp passes are all attributes. Logically the attention mask should be an input, because the mask may be different for every inference.

nihui commented 8 months ago
torch.unbind            op_3        1 3 4 5 6 7 dim=0

The output 5 here is never used?

DamonsJ commented 8 months ago
torch.unbind            op_3        1 3 4 5 6 7 dim=0

The output 5 here is never used?

Sorry, my mistake.

I changed it:


class fuse_multiheadattention_pass_20 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 16
pnnx.Input              input_0     0 1 0
pnnx.Input              attn_mask   0 1 1
nn.Linear               op_0        1 1 0 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape          op_1        1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute           op_2        1 1 3 4 dims=(2,0,3,1,4)
torch.unbind            op_3        1 3 4 5 6 7 dim=0
torch.permute           op_4        1 1 6 8 dims=(0,1,3,2)
torch.matmul            op_5        2 1 5 8 9
pnnx.Expression         op_6        2 1 9 1 10 expr=add(mul(@0,%inv_sqrt_embed_dim_per_head),@1)
F.softmax               op_7        1 1 10 11 dim=-1
torch.matmul            op_8        2 1 11 7 12
torch.permute           op_9        1 1 12 13 dims=(0,2,1,3)
Tensor.reshape          op_10       1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear               out_proj    1 1 14 15 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output             output      1 0 15
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input              input       0 1 input
pnnx.Input              attn_mask   0 1 attn_mask
nn.MultiheadAttention   attention   2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
pnnx.Output             output      1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f;
        printf(" <=====> inv_sqrt_embed_dim_per_head %.3f\n", inv_sqrt_embed_dim_per_head);
        printf(" <=====> embed_dim %3d qkv_out_features %3d num_heads %3d feat_per_head %3d\n", embed_dim,qkv_out_features,num_heads,feat_per_head);
        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        Operator* op = ops.at("attention");

        const int embed_dim = captured_params.at("embed_dim").i;
        const bool qkvbias = captured_params.at("qkvbias").b;
        const bool outbias = captured_params.at("outbias").b;
        const bool bias = qkvbias || outbias;

        op->params["bias"] = bias;

        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight");
        if (bias)
        {
            if (qkvbias)
            {
                op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["in_proj_bias"] = Attribute();
                op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
                op->attrs["in_proj_bias"].shape = {embed_dim * 3};
                op->attrs["in_proj_bias"].set_float32_data(std::vector<float>(embed_dim * 3, 0.f));
            }
        }

        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");
        if (bias)
        {
            if (outbias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
                op->attrs["out_proj.bias"].shape = {embed_dim};
                op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
            }
        }

        Operator* op_attr = ops.at("attn_mask");
        const int batch = captured_params.at("batch").i;
        const int size = captured_params.at("size").i;
        printf(" <=====>batch  %2d size : %2d \n", batch,size);
        // hack attn_mask shape
        op_attr->attrs["data"].shape = {batch, 1, size, size};

        // hack attn_mask value
        std::vector<char>& data = op_attr->attrs["data"].data;
        size_t len = data.size();
        data.resize(len * batch);
        for (int i = 1; i < batch; i++)
        {
            memcpy(&data[len * i], &data[0], len);
        }

    }
};

And exported it from Python like this:

class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale + atten_mask
        attn = F.softmax(attn, dim=-1)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.attention_0_0 = TSelfAttention(embed_dim=768, num_heads=12)
    def forward(self, x, attention_mask):
        a = self.attention_0_0(x,attention_mask)
        return a

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2, 128, 768)
    attention_mask = torch.rand(2,1,128, 128)
    r = net(x,attention_mask)

    # export torchscript
    mod = torch.jit.trace(net, (x,attention_mask))
    mod.save("test_bert_fused.pt")

    # torchscript to pnnx
    import os
    os.system("../build/src/pnnx test_bert_fused.pt inputshape=[2,128,768],[2,1,128,128]")              

This is still not correct.

DamonsJ commented 8 months ago

How should I debug this in order to fuse this type of operator correctly?
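
One simple way to check whether a pattern actually matched is to dump the operator list before and after the rewrite. A minimal sketch, assuming it sits next to the pass sources so that "ir.h" is on the include path (the helper name is mine, not part of pnnx):

#include <cstdio>

#include "ir.h" // pnnx Graph / Operator definitions

// print every remaining operator, so an unfused pattern is easy to spot
static void dump_graph_ops(const pnnx::Graph& graph, const char* tag)
{
    fprintf(stderr, "===== %s =====\n", tag);
    for (const pnnx::Operator* op : graph.ops)
    {
        fprintf(stderr, "%-24s %s\n", op->type.c_str(), op->name.c_str());
    }
}

// usage inside fuse_multiheadattention(), around the new pass:
//   dump_graph_ops(graph, "before pass_20");
//   pnnx_graph_rewrite(graph, &r, opindex);
//   dump_graph_ops(graph, "after pass_20");
//
// if nn.MultiheadAttention does not appear in the second dump, the match
// pattern (operand names, shapes, expression string) did not line up with
// the traced graph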

DamonsJ commented 8 months ago

After some debugging, I found that the following works:

class fuse_multiheadattention_pass_20 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 16
pnnx.Input              input_0     0 1 0
pnnx.Input              attn_mask   0 1 1
nn.Linear               op_0        1 1 0 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape          op_1        1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute           op_2        1 1 3 4 dims=(2,0,3,1,4)
torch.unbind            op_3        1 3 4 5 6 7 dim=0
torch.permute           op_4        1 1 6 8 dims=(0,1,3,2)
torch.matmul            op_5        2 1 5 8 9
pnnx.Expression         op_6        2 1 9 1 10 expr=add(mul(@0,%inv_sqrt_embed_dim_per_head),@1)
F.softmax               op_7        1 1 10 11 dim=-1
torch.matmul            op_8        2 1 11 7 12
torch.permute           op_9        1 1 12 13 dims=(0,2,1,3)
Tensor.reshape          op_10       1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear               out_proj    1 1 14 15 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output             output      1 0 15
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
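        // note: compared with the previous attempt, the boundary operands below
        // reuse the operand names from the match graph ("0", "1", "15") instead
        // of fresh names, which appears to be what lets the rewriter rebind the
        // fused operator to the original edges; the attn_mask data hack in
        // write() is dropped as well, since attn_mask is now a real graph input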
        return R"PNNXIR(7767517
4 3
pnnx.Input              input       0 1 0
pnnx.Input              attn_mask   0 1 1
nn.MultiheadAttention   attention   2 1 0 1 15 embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
pnnx.Output             output      1 0 15
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f;
        printf(" <=====> inv_sqrt_embed_dim_per_head %.3f\n", inv_sqrt_embed_dim_per_head);
        printf(" <=====> embed_dim %3d qkv_out_features %3d num_heads %3d feat_per_head %3d\n", embed_dim,qkv_out_features,num_heads,feat_per_head);
        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        printf(" ====> fuse_multiheadattention_pass_20 write \n");
        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        for(auto &item : ops) {
            printf(" ====> fuse_multiheadattention_pass_20  op : %s \n", item.first.c_str());
        }

        for(auto &item : captured_params) {
            printf(" ====> fuse_multiheadattention_pass_20  param : %s \n", item.first.c_str());
        }

        Operator* op = ops.at("attention");
        printf(" ====> fuse_multiheadattention_pass_20 write op : %p \n", op);

        const int embed_dim = captured_params.at("embed_dim").i;
        const bool qkvbias = captured_params.at("qkvbias").b;
        const bool outbias = captured_params.at("outbias").b;
        const bool bias = qkvbias || outbias;

        op->params["bias"] = bias;

        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight");
        if (bias)
        {
            if (qkvbias)
            {
                op->attrs["in_proj_bias"] = captured_attrs.at("op_0.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["in_proj_bias"] = Attribute();
                op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
                op->attrs["in_proj_bias"].shape = {embed_dim * 3};
                op->attrs["in_proj_bias"].set_float32_data(std::vector<float>(embed_dim * 3, 0.f));
            }
        }

        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");
        if (bias)
        {
            if (outbias)
            {
                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
            }
            else
            {
                // init bias as zero
                op->attrs["out_proj.bias"] = Attribute();
                op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
                op->attrs["out_proj.bias"].shape = {embed_dim};
                op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
            }
        }
        printf(" ====> fuse_multiheadattention_pass_20 write o\n");

        const int batch = captured_params.at("batch").i;
        const int size = captured_params.at("size").i;
        printf(" <=====>batch  %2d size : %2d \n", batch,size);

    }
};

class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale + atten_mask
        attn = F.softmax(attn, dim=-1)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x

Thanks a lot! Could this be supported in a future version, or could I open a PR for it?