Oneflow-Inc / libai

LiBai(李白): A Toolbox for Large-Scale Distributed Parallel Training
https://libai.readthedocs.io
Apache License 2.0

About Load HuggingFace Bert #205

Open xiezipeng-ML opened 2 years ago

xiezipeng-ML commented 2 years ago

While using LiBai's Bert to load huggingface weights and align the outputs, I found a few problems; after the modifications below, the outputs can be aligned with huggingface.

For the parameter structure comparison, you can first look at the Bert parameter structures of the two libraries at the bottom of this post.

LiBai's Bert and huggingface's Bert differ in their internal computation logic, which causes the outputs not to align:

After my modification, the results align with huggingface's q, k, v:

```python
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
query = query.view(query.size(0), query.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
key = key.view(key.size(0), key.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
value = value.view(value.size(0), value.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
```

- The other difference is that parts of the internal computation logic of LiBai's `TransformerLayer` differ from **huggingface**'s, which likewise prevents **LiBai**'s output from aligning with **huggingface**:
```python
# This difference in computation makes all subsequent outputs inconsistent,
# e.g. the input received by the MLP layer is also different.
# Original code:
# https://github.com/Oneflow-Inc/libai/blob/main/libai/layers/transformer_layer.py#L176
hidden_states = hidden_states + attention_output

# My modification:
hidden_states = layernorm_output + attention_output
```

That is, in LiBai, hidden_states is obtained by adding the self-attention result attention_output to the input of the TransformerLayer. Bert has 12 TransformerLayers, and the input of the first TransformerLayer is the output of the Embedding layer.

In huggingface, however, hidden_states is obtained by adding attention_output to the TransformerLayer input after that input has passed through a LayerNorm. In other words, in LiBai the input is added into hidden_states without going through LayerNorm, which looks unreasonable.

After the modification:

```python
output = layernorm_output + mlp_output
```

That is, the final output of LiBai's TransformerLayer is now the sum of mlp_output and layernorm_output; huggingface uses layernorm_output here as well (see the sketch after this list).

- After fixing the problems above, I set `bias_gelu_fusion`, `bias_dropout_fusion`, and `apply_query_key_layer_scaling` in **LiBai**'s `Bert` to `False`, and then wrote a function to load the **huggingface** pretrained model. After loading, **LiBai**'s `Bert` with **huggingface**'s weights produces the same output as **huggingface**'s `Bert` (using the same sentence as input).
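
To make the two orderings concrete, here is a minimal sketch of a single TransformerLayer forward pass under the original LiBai/Megatron arrangement and under the modified, huggingface-aligned arrangement (function names are illustrative; dropout and the fused bias ops are omitted):

```python
import torch
from torch import nn

def megatron_style_layer(x, input_layernorm, self_attention, post_attention_layernorm, mlp):
    """Original LiBai ordering: both residuals skip the LayerNorms."""
    layernorm_output = input_layernorm(x)
    attention_output = self_attention(layernorm_output)
    hidden_states = x + attention_output                        # residual = raw layer input
    layernorm_output = post_attention_layernorm(hidden_states)
    mlp_output = mlp(layernorm_output)
    return hidden_states + mlp_output

def huggingface_aligned_layer(x, input_layernorm, self_attention, post_attention_layernorm, mlp):
    """Modified ordering described above: both residuals use the LayerNorm-ed tensor."""
    layernorm_output = input_layernorm(x)
    attention_output = self_attention(layernorm_output)
    hidden_states = layernorm_output + attention_output         # residual = LayerNorm-ed input
    layernorm_output = post_attention_layernorm(hidden_states)
    mlp_output = mlp(layernorm_output)
    return layernorm_output + mlp_output

# The two orderings diverge even with identical sub-modules:
ln1, ln2 = nn.LayerNorm(8), nn.LayerNorm(8)
attn, mlp = nn.Linear(8, 8), nn.Linear(8, 8)   # toy stand-ins for the real sub-modules
x = torch.rand(2, 4, 8)
print(torch.allclose(megatron_style_layer(x, ln1, attn, ln2, mlp),
                     huggingface_aligned_layer(x, ln1, attn, ln2, mlp)))   # almost surely False
```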

#### First, the Bert parameter structure in LiBai
```python
embeddings.vocab_embeddings.weight oneflow.Size([30522, 768])
embeddings.position_embeddings.weight oneflow.Size([512, 768])
embeddings.tokentype_embeddings.weight oneflow.Size([2, 768])

encoders.0.input_layernorm.weight oneflow.Size([768])
encoders.0.input_layernorm.bias oneflow.Size([768])

encoders.0.self_attention.query_key_value.weight oneflow.Size([768, 2304])
encoders.0.self_attention.query_key_value.bias oneflow.Size([2304])
encoders.0.self_attention.dense.weight oneflow.Size([768, 768])
encoders.0.self_attention.dense.bias oneflow.Size([768])

encoders.0.post_attention_layernorm.weight oneflow.Size([768])
encoders.0.post_attention_layernorm.bias oneflow.Size([768])

encoders.0.mlp.dense_h_to_4h.weight oneflow.Size([768, 3072])
encoders.0.mlp.dense_h_to_4h.bias oneflow.Size([3072])

encoders.0.mlp.dense_4h_to_h.weight oneflow.Size([3072, 768])
encoders.0.mlp.dense_4h_to_h.bias oneflow.Size([768])

encoders.1.input_layernorm.weight oneflow.Size([768])
encoders.1.input_layernorm.bias oneflow.Size([768])

encoders.1.self_attention.query_key_value.weight oneflow.Size([768, 2304])
encoders.1.self_attention.query_key_value.bias oneflow.Size([2304])
encoders.1.self_attention.dense.weight oneflow.Size([768, 768])
encoders.1.self_attention.dense.bias oneflow.Size([768])
encoders.1.post_attention_layernorm.weight oneflow.Size([768])
encoders.1.post_attention_layernorm.bias oneflow.Size([768])
encoders.1.mlp.dense_h_to_4h.weight oneflow.Size([768, 3072])
encoders.1.mlp.dense_h_to_4h.bias oneflow.Size([3072])
encoders.1.mlp.dense_4h_to_h.weight oneflow.Size([3072, 768])
encoders.1.mlp.dense_4h_to_h.bias oneflow.Size([768])

final_layernorm.weight oneflow.Size([768])
final_layernorm.bias oneflow.Size([768])
pooler.dense.weight oneflow.Size([768, 768])
pooler.dense.bias oneflow.Size([768])
```

#### Then, the Bert parameter structure in huggingface

```python
bert.embeddings.word_embeddings.weight torch.Size([30522, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.gamma torch.Size([768])
bert.embeddings.LayerNorm.beta torch.Size([768])

bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.beta torch.Size([768])

bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])

bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.beta torch.Size([768])

bert.encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.beta torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.1.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.beta torch.Size([768])

bert.pooler.dense.weight torch.Size([768, 768])
bert.pooler.dense.bias torch.Size([768])
```
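
For reference, the two listings above can be reproduced with something along these lines (a sketch; `libai_bert` is a hypothetical variable holding an already-built libai `BertModel`, and the checkpoint path is the one used later in this thread):

```python
import torch

# LiBai side: iterate over the model's parameters.
for name, param in libai_bert.named_parameters():
    print(name, param.size())

# huggingface side: the LayerNorm.gamma / LayerNorm.beta names above indicate an
# older-format checkpoint, so iterate the raw state dict of pytorch_model.bin.
state_dict = torch.load("./pretrain/pytorch_model.bin", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tensor.size())
```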
L1aoXingyu commented 2 years ago

A quick reply on the layernorm placement question.

We follow the megatron code implementation; regarding the placement of the residual connection, the megatron-lm paper says:

> We further investigated this behavior and empirically demonstrated that rearranging the order of the layer normalization and the residual connections as shown in Figure 7 is critical to enable the scaling of the BERT-style models beyond BERT-Large. The architecture (b) in Figure 7 eliminates instabilities observed using the original BERT architecture in (a) and also has a lower training loss.

(image: Figure 7 from the Megatron-LM paper, comparing the two LayerNorm/residual arrangements)

So the layernorm placement in libai's TransformerLayer differs from the original bert. @xiezipeng-ML

xiezipeng-ML commented 2 years ago


The LayerNorm placement is indeed not a problem; the computation is correct.

L1aoXingyu commented 2 years ago

For the qkv computation, we also follow the megatron code below; the difference is that libai does not move the sequence dimension to the front. The corresponding flows are:

```
libai:
                    view                  permute                split
[b, sq, (np*3*hn)] --> [b, sq, np, 3*hn] --> [b, np, sq, 3*hn] --> [b, np, sq, hn]

huggingface:
                    split                 view               permute
[b, sq, (np*3*hn)] --> [b, sq, (np*hn)] --> [b, sq, np, hn] --> [b, np, sq, hn]
```

https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/model/transformer.py#L221-L232

rentainhe commented 2 years ago

This issue is really good; it could be organized separately into an FAQ section or into the Advanced Tutorials.

xiezipeng-ML commented 2 years ago

(image: notes on how the two qkv computation methods lead to different sbp signatures)

I recorded the issue of the two qkv computation methods producing different sbp. @CPFLAME

L1aoXingyu commented 2 years ago

We worked through the derivation and found that aligning with the huggingface style breaks the sbp we derived earlier: huggingface does the chunk first, and the chunk dimension is exactly the sbp.split dimension, so the chunk implicitly introduces a communication step. We think this could bring further problems, so could you try the approach Kaijie suggested earlier?

https://github.com/Oneflow-Inc/libai/issues/146#issuecomment-1054953921

xiezipeng-ML commented 2 years ago


OK Xingyu, I will look into how to load the weights correctly so that the outputs align.

xiezipeng-ML commented 2 years ago

Without changing the model, we can get the same result by changing how the weights are loaded.

Since we have already established that libai's qkv computation is correct (switching to huggingface's qkv computation breaks the sbp under model parallelism; the only current workaround is a direct to_global, and it is unclear whether that would cause other problems; in other words, the sbp scheme of the whole model in libai is already worked out, and switching to another computation causes problems), we instead consider a different way of loading the weights.

The two qkv computation methods:

```python
# qkv computation in LiBai:
# query_key_value: [batch_size, seq_len, 3*hidden_size]
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)  # (a)
query_key_value = query_key_value.permute(0, 2, 1, 3)                                # (b)
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)                    # (c)

# qkv computation in huggingface:
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
query = query.view(query.size(0), query.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
key = key.view(key.size(0), key.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
value = value.view(value.size(0), value.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
```

First, an explanation of why the qkv computation in LiBai's MultiheadAttention cannot produce the same result after directly loading the huggingface weights: a fused qkv weight in the huggingface layout orders its output rows as all q heads, then all k heads, then all v heads, whereas LiBai's view → permute → chunk expects the rows to be grouped per head as [q, k, v], so the same weight matrix is partitioned into different q/k/v tensors.
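
A toy illustration of the mismatch (hypothetical tiny sizes, channel labels only):

```python
# With 2 heads and head_size 1 (hypothetical sizes), huggingface orders the fused qkv
# output channels as all q heads, then all k heads, then all v heads:
num_heads, head_size = 2, 1
hf_channels = ["q_h0", "q_h1", "k_h0", "k_h1", "v_h0", "v_h1"]

# LiBai's view(bsz, seq_len, num_heads, 3*head_size) assumes the channels are instead
# grouped per head as [q_h, k_h, v_h]; applying that grouping to the huggingface layout
# mixes q/k/v across heads:
per_head = [hf_channels[h * 3 * head_size:(h + 1) * 3 * head_size] for h in range(num_heads)]
print(per_head)   # [['q_h0', 'q_h1', 'k_h0'], ['k_h1', 'v_h0', 'v_h1']] -- not [q, k, v] per head
```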

The idea for the fix:

```python
import torch
import torch.nn.functional as F

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
# bias = torch.rand(2304)

# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)            # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size)                     # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)

weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)

weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1)     # [12*head_size, 3, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------

weight2 = weight
qkv1 = F.linear(x, weight1, bias=None)
qkv2 = F.linear(x, weight2, bias=None)

# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)
```
rentainhe commented 2 years ago

The bias solution:

```python
import torch
import torch.nn.functional as F

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)

# my method for weight------------------------------
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_temp = weight1

weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)   # [4, 3, head_size, hidden_size]
weight_q = weight_q.view(-1, head_size, hidden_size)            # [12, head_size, hidden_size]
weight_k = weight_k.view(-1, head_size, hidden_size)
weight_v = weight_v.view(-1, head_size, hidden_size)

weight_q = weight_q.unsqueeze(1)
weight_k = weight_k.unsqueeze(1)
weight_v = weight_v.unsqueeze(1)

weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1)      # [12*head_size, 3, hidden_size]
weight1 = weight1.view(-1, hidden_size)
# my method for weight end-----------------------------------------------------

weight2 = weight

# --------------convert bias-------------------------------
bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)
# -----------------------------------------------------------

qkv1 = F.linear(x, weight1, bias=bias1)   # 2304, 768
qkv2 = F.linear(x, weight2, bias=bias)

# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)
```


@xiezipeng-ML 

## The cleaned-up code
```python
import torch
import torch.nn.functional as F

bsz = 32
seq_len = 5
num_heads = 12
head_size = 64
hidden_size = num_heads*head_size

x = torch.rand(bsz, seq_len, hidden_size)
weight = torch.rand(hidden_size*3, hidden_size)
bias = torch.rand(2304)

# convert weight and bias
weight1 = weight.view([num_heads, 3, head_size, hidden_size])
weight_q, weight_k, weight_v = weight1.chunk(chunks=3, dim=0)
weight_q = weight_q.view(-1, head_size, hidden_size).unsqueeze(1)
weight_k = weight_k.view(-1, head_size, hidden_size).unsqueeze(1)
weight_v = weight_v.view(-1, head_size, hidden_size).unsqueeze(1)
weight1 = torch.cat([weight_q, weight_k, weight_v], dim=1).view(-1, hidden_size)

bias_ = bias.view(num_heads, 3, head_size)
bias_q, bias_k, bias_v = bias_.chunk(3, dim=0)
bias_q = bias_q.view(-1, head_size).unsqueeze(1)
bias_k = bias_k.view(-1, head_size).unsqueeze(1)
bias_v = bias_v.view(-1, head_size).unsqueeze(1)
bias1 = torch.cat([bias_q, bias_k, bias_v], dim=1).view(-1)

weight2 = weight
bias2 = bias

qkv1 = F.linear(x, weight1, bias=bias1)
qkv2 = F.linear(x, weight2, bias=bias2)

# libai------------------------------------------
qkv1 = qkv1.view(bsz, seq_len, num_heads, 3*head_size)
qkv1 = qkv1.permute(0, 2, 1, 3)
q1, k1, v1 = torch.chunk(qkv1, chunks=3, dim=-1)

# huggingface------------------------------------------
q2, k2, v2 = torch.chunk(qkv2, chunks=3, dim=-1)
q2 = q2.view(q2.size(0), q2.size(1), num_heads, -1).transpose(1,2)
k2 = k2.view(k2.size(0), k2.size(1), num_heads, -1).transpose(1,2)
v2 = v2.view(v2.size(0), v2.size(1), num_heads, -1).transpose(1,2)

print((q1==q2).all())     # tensor(True)
print((k1==k2).all())     # tensor(True)
print((v1==v2).all())     # tensor(True)
```
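
For the actual weight-loading function, the same conversion can be packaged into small helpers, for example (a sketch with hypothetical names, not the implementation of LiBai's `load_huggingface_bert`):

```python
import torch

def convert_qkv_weight(weight, num_heads, head_size):
    """Reorder a fused qkv weight of shape [3*hidden, hidden] from the huggingface
    order (all q rows, then all k rows, then all v rows, e.g. torch.cat of the
    separate q/k/v weights) into the per-head [q, k, v] row order expected by
    LiBai's query_key_value."""
    hidden_size = num_heads * head_size
    wq, wk, wv = weight.view(num_heads, 3, head_size, hidden_size).chunk(3, dim=0)
    wq = wq.reshape(num_heads, 1, head_size, hidden_size)
    wk = wk.reshape(num_heads, 1, head_size, hidden_size)
    wv = wv.reshape(num_heads, 1, head_size, hidden_size)
    # Note: LiBai stores query_key_value.weight as [hidden, 3*hidden] (see the listing
    # above), so a transpose may be needed when assigning the result.
    return torch.cat([wq, wk, wv], dim=1).reshape(-1, hidden_size)

def convert_qkv_bias(bias, num_heads, head_size):
    """Same per-head reordering for the fused qkv bias of shape [3*hidden]."""
    bq, bk, bv = bias.view(num_heads, 3, head_size).chunk(3, dim=0)
    bq = bq.reshape(num_heads, 1, head_size)
    bk = bk.reshape(num_heads, 1, head_size)
    bv = bv.reshape(num_heads, 1, head_size)
    return torch.cat([bq, bk, bv], dim=1).reshape(-1)

# Equivalent to weight1 / bias1 in the script above:
# weight1 = convert_qkv_weight(weight, num_heads, head_size)
# bias1   = convert_qkv_bias(bias, num_heads, head_size)
```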
xiezipeng-ML commented 2 years ago

After loading the pretrained weights, the bert outputs are aligned:

```python
import oneflow as flow
import libai
from libai.models import build_model
from libai.config import LazyCall
from load_huggingface_weight import load_huggingface_bert
from libai.utils import distributed as dist
import transformers
import torch
import numpy as np

input_ids = [[101, 1962, 2110, 739, 999, 1, 2, 3, 102]]
mask = [[1] * len(input_ids[0])]  # attention mask covering all 9 input tokens

# libai result
cfg = dict(
    vocab_size=21128,
    hidden_size=768,
    hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    num_tokentypes=2,
    add_pooling_layer=True,
    initializer_range=0.02,
    layernorm_eps=1e-12,
    bias_gelu_fusion=False,               # set to False to match huggingface
    bias_dropout_fusion=False,            # set to False to match huggingface
    scale_mask_softmax_fusion=False,
    apply_query_key_layer_scaling=False,  # set to False to match huggingface
    add_binary_head=True,
    amp_enabled=False,
    apply_residual_post_layernorm=True
)
bert_lib = build_model(LazyCall(libai.models.BertModel)(cfg=cfg))
load_huggingface_bert(bert_lib, './pretrain/pytorch_model.bin', cfg['hidden_size'], cfg['num_attention_heads'])
input_of = flow.tensor(input_ids, dtype=flow.long, sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]), placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),)
mask_of = flow.tensor(mask, dtype=flow.long, sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]), placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),)
bert_lib.eval()
last_hidden_state_of, pooler_output_of = bert_lib(input_of, mask_of)

# huggingface result
bert_hug = transformers.BertModel.from_pretrained('./pretrain')
bert_hug.eval()
input_pt = torch.tensor(input_ids)
mask_pt = torch.tensor(mask)
last_hidden_state_pt = bert_hug(input_pt, mask_pt).last_hidden_state 

res1 = last_hidden_state_of.detach().numpy()
res2 = last_hidden_state_pt.detach().numpy()
print(res1.sum())
print(res2.sum())
```
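
A slightly stricter check than comparing the two sums, assuming the script above has just been run:

```python
import numpy as np

# Element-wise comparison of the two final hidden states; a small tolerance
# absorbs fp32 rounding differences between the two frameworks.
print(np.allclose(res1, res2, atol=1e-5))
print(np.abs(res1 - res2).max())
```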