PaddlePaddle / PaDiff

Paddle Automatically Diff Precision Toolkits.

MultiheadAttention initialization fails #97

Closed. ToddBear closed this issue 1 year ago.

ToddBear commented 1 year ago

I added a MultiheadAttention layer to the sample code and tried to copy the parameter values across, but it fails.

Versions: paddlepaddle-gpu == 2.4.2, torch == 1.12.0+cu102

Code:

import paddle
import torch
from padiff import assign_weight, create_model

class SimpleModule(torch.nn.Module):
    def __init__(self):
        super(SimpleModule, self).__init__()
        self.linear1 = torch.nn.Linear(100, 10)
        # torch fuses the Q/K/V projections inside MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(64, 8, dropout=0.1, batch_first=True)

    def forward(self, x):
        x = self.linear1(x)
        return x

class SimpleLayer(paddle.nn.Layer):
    def __init__(self):
        super(SimpleLayer, self).__init__()
        self.linear1 = paddle.nn.Linear(100, 10)
        # paddle exposes q_proj / k_proj / v_proj / out_proj as separate Linear sublayers
        self.attention = paddle.nn.MultiHeadAttention(64, 8, dropout=0.1)

    def forward(self, x):
        x = self.linear1(x)
        return x

# Wrap both models and build the layer maps automatically
module = create_model(SimpleModule())
module.auto_layer_map("base")
layer = create_model(SimpleLayer())
layer.auto_layer_map("raw")

# Copy weights from the torch model (base) to the paddle model (raw)
assign_weight(module, layer)

Error:

RuntimeError: Error occured when trying init weights, between:
    base_model: MultiheadAttention() SimpleModule.attention.in_proj_weight
    raw_model: Linear(in_features=64, out_features=64, dtype=float32) SimpleLayer.attention.q_proj.weight

Model architecture log file padiff_log/weight_init_SimpleModule.log:

SimpleModule
========================================
    SimpleModule
     |--- Linear
     +--- MultiheadAttention    <---  *** HERE ***
           +--- NonDynamicallyQuantizableLinear  (skip)

padiff_log/weight_init_SimpleLayer.log:

SimpleLayer
========================================
    SimpleLayer
     |--- Linear
     +--- MultiHeadAttention  (skip)
           |--- Linear    <---  *** HERE ***
           |--- Linear
           |--- Linear
           +--- Linear
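
For context on where the mismatch comes from: torch's MultiheadAttention stores the query/key/value projections fused in a single in_proj_weight of shape (3 * embed_dim, embed_dim) (plus in_proj_bias), while paddle's MultiHeadAttention keeps four separate Linear sublayers (q_proj, k_proj, v_proj, out_proj), and paddle.nn.Linear stores its weight transposed, as (in_features, out_features). So there is no one-to-one sublayer correspondence for the automatic mapping to use. A minimal sketch of what a manual copy would look like outside PaDiff (assuming the default case where embed_dim == kdim == vdim, so in_proj_weight exists):

import paddle
import torch

embed_dim, num_heads = 64, 8
torch_attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
paddle_attn = paddle.nn.MultiHeadAttention(embed_dim, num_heads)

# in_proj_weight stacks the Q, K, V projection matrices row-wise: (3 * embed_dim, embed_dim)
qkv_w = torch_attn.in_proj_weight.detach().numpy()
qkv_b = torch_attn.in_proj_bias.detach().numpy()

for i, proj in enumerate([paddle_attn.q_proj, paddle_attn.k_proj, paddle_attn.v_proj]):
    w = qkv_w[i * embed_dim:(i + 1) * embed_dim, :]
    b = qkv_b[i * embed_dim:(i + 1) * embed_dim]
    # paddle.nn.Linear keeps weight as (in_features, out_features), the transpose of torch's layout
    proj.weight.set_value(paddle.to_tensor(w.T))
    proj.bias.set_value(paddle.to_tensor(b))

# out_proj is an ordinary Linear on both sides; only the transpose is needed
paddle_attn.out_proj.weight.set_value(paddle.to_tensor(torch_attn.out_proj.weight.detach().numpy().T))
paddle_attn.out_proj.bias.set_value(paddle.to_tensor(torch_attn.out_proj.bias.detach().numpy()))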

How should I modify this to make it work?

feifei-111 commented 1 year ago

There was a bug here; it has been fixed. Just pull the develop branch.
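
For anyone hitting the same error before a release includes the fix, installing PaDiff directly from the develop branch should pick it up, e.g. pip install -U "git+https://github.com/PaddlePaddle/PaDiff.git@develop" (the exact install command is an assumption; the branch name comes from the reply above).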