fwaris / BertTorchSharp

.NET Interactive notebook showing how to create a BERT model in TorchSharp, load pre-trained weights, and retrain it for text classification

Load pre-trained weights from pytorch_model.bin #2

Status: Open. Opened by GeorgeS2019 2 years ago

GeorgeS2019 commented 2 years ago

Note: the weights can also be downloaded from Hugging Face; however, they are not easily consumed from languages other than Python.

bert_uncased_L-2_H-128_A-2

pytorch_model.bin

@fwaris With this PR implemented, it is now possible to extract the pre-trained weights from pytorch_model.bin and save them for use in TorchSharp as (name, weight tensor) pairs.
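
A minimal sketch of the consuming side (not the PR's code; readFloats is a hypothetical helper that returns the raw float32 payload for a named weight): once a weight is available as a name plus a flat float32 array, it can be turned into a TorchSharp tensor and reshaped to its expected dimensions.

    open TorchSharp

    // Sketch only: build a TorchSharp tensor from one extracted (name, data, shape) entry.
    // 'readFloats' is hypothetical; shapes follow the bert_uncased_L-2_H-128_A-2 config.
    let toTensor (data: float32[]) (shape: int64[]) =
        torch.tensor(data).reshape(shape)

    // e.g. bert.embeddings.word_embeddings.weight has shape [30522; 128]
    // let wordEmb = toTensor (readFloats "bert.embeddings.word_embeddings.weight") [| 30522L; 128L |]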

GeorgeS2019 commented 2 years ago
converted bert.embeddings.word_embeddings.weight - 15627392 bytes
converted bert.embeddings.position_embeddings.weight - 262272 bytes
converted bert.embeddings.token_type_embeddings.weight - 1152 bytes
converted bert.embeddings.LayerNorm.weight - 624 bytes
converted bert.embeddings.LayerNorm.bias - 624 bytes
converted bert.encoder.layer.0.attention.self.query.weight - 65664 bytes
converted bert.encoder.layer.0.attention.self.query.bias - 624 bytes
converted bert.encoder.layer.0.attention.self.key.weight - 65664 bytes
converted bert.encoder.layer.0.attention.self.key.bias - 624 bytes
converted bert.encoder.layer.0.attention.self.value.weight - 65664 bytes
converted bert.encoder.layer.0.attention.self.value.bias - 624 bytes
converted bert.encoder.layer.0.attention.output.dense.weight - 65664 bytes
converted bert.encoder.layer.0.attention.output.dense.bias - 624 bytes
converted bert.encoder.layer.0.attention.output.LayerNorm.weight - 624 bytes
converted bert.encoder.layer.0.attention.output.LayerNorm.bias - 624 bytes
converted bert.encoder.layer.0.intermediate.dense.weight - 262272 bytes
converted bert.encoder.layer.0.intermediate.dense.bias - 2160 bytes
converted bert.encoder.layer.0.output.dense.weight - 262272 bytes
converted bert.encoder.layer.0.output.dense.bias - 624 bytes
converted bert.encoder.layer.0.output.LayerNorm.weight - 624 bytes
converted bert.encoder.layer.0.output.LayerNorm.bias - 624 bytes
converted bert.encoder.layer.1.attention.self.query.weight - 65664 bytes
converted bert.encoder.layer.1.attention.self.query.bias - 624 bytes
converted bert.encoder.layer.1.attention.self.key.weight - 65664 bytes
converted bert.encoder.layer.1.attention.self.key.bias - 624 bytes
converted bert.encoder.layer.1.attention.self.value.weight - 65664 bytes
converted bert.encoder.layer.1.attention.self.value.bias - 624 bytes
converted bert.encoder.layer.1.attention.output.dense.weight - 65664 bytes
converted bert.encoder.layer.1.attention.output.dense.bias - 624 bytes
converted bert.encoder.layer.1.attention.output.LayerNorm.weight - 624 bytes
converted bert.encoder.layer.1.attention.output.LayerNorm.bias - 624 bytes
converted bert.encoder.layer.1.intermediate.dense.weight - 262272 bytes
converted bert.encoder.layer.1.intermediate.dense.bias - 2160 bytes
converted bert.encoder.layer.1.output.dense.weight - 262272 bytes
converted bert.encoder.layer.1.output.dense.bias - 624 bytes
converted bert.encoder.layer.1.output.LayerNorm.weight - 624 bytes
converted bert.encoder.layer.1.output.LayerNorm.bias - 624 bytes
converted bert.pooler.dense.weight - 65664 bytes
converted bert.pooler.dense.bias - 624 bytes
converted cls.predictions.bias - 122200 bytes
converted cls.predictions.transform.dense.weight - 65664 bytes
converted cls.predictions.transform.dense.bias - 624 bytes
converted cls.predictions.transform.LayerNorm.weight - 624 bytes
converted cls.predictions.transform.LayerNorm.bias - 624 bytes
converted cls.predictions.decoder.weight - 15627392 bytes
converted cls.predictions.decoder.bias - 122200 bytes
converted cls.seq_relationship.weight - 1152 bytes
converted cls.seq_relationship.bias - 120 bytes

Compared with the weights listed in the notebook (index | shape | TF name):

 0 | [ 128 ]        | bert/embeddings/LayerNorm/beta
 1 | [ 128 ]        | bert/embeddings/LayerNorm/gamma
 2 | [ 512, 128 ]   | bert/embeddings/position_embeddings
 3 | [ 2, 128 ]     | bert/embeddings/token_type_embeddings
 4 | [ 30522, 128 ] | bert/embeddings/word_embeddings
 5 | [ 128 ]        | bert/encoder/layer_0/attention/output/LayerNorm/beta
 6 | [ 128 ]        | bert/encoder/layer_0/attention/output/LayerNorm/gamma
 7 | [ 128 ]        | bert/encoder/layer_0/attention/output/dense/bias
 8 | [ 128, 128 ]   | bert/encoder/layer_0/attention/output/dense/kernel
 9 | [ 128 ]        | bert/encoder/layer_0/attention/self/key/bias
10 | [ 128, 128 ]   | bert/encoder/layer_0/attention/self/key/kernel
11 | [ 128 ]        | bert/encoder/layer_0/attention/self/query/bias
12 | [ 128, 128 ]   | bert/encoder/layer_0/attention/self/query/kernel
13 | [ 128 ]        | bert/encoder/layer_0/attention/self/value/bias
14 | [ 128, 128 ]   | bert/encoder/layer_0/attention/self/value/kernel
15 | [ 512 ]        | bert/encoder/layer_0/intermediate/dense/bias
16 | [ 128, 512 ]   | bert/encoder/layer_0/intermediate/dense/kernel
17 | [ 128 ]        | bert/encoder/layer_0/output/LayerNorm/beta
18 | [ 128 ]        | bert/encoder/layer_0/output/LayerNorm/gamma
19 | [ 128 ]        | bert/encoder/layer_0/output/dense/bias
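
For reference (an assumed mapping, not stated in the thread): the two listings use different naming conventions. In the TF-style names, LayerNorm gamma/beta correspond to PyTorch weight/bias, a dense kernel corresponds to a (transposed) weight, layer_N corresponds to layer.N, and '/' corresponds to '.'.

    // A rough correspondence sketch (assumption, not part of the original comparison):
    let nameExamples =
        [ // pytorch_model.bin name                              TF checkpoint name (notebook)
          "bert.embeddings.LayerNorm.weight",                    "bert/embeddings/LayerNorm/gamma"
          "bert.embeddings.LayerNorm.bias",                      "bert/embeddings/LayerNorm/beta"
          "bert.embeddings.word_embeddings.weight",              "bert/embeddings/word_embeddings"
          "bert.encoder.layer.0.attention.self.query.weight",    "bert/encoder/layer_0/attention/self/query/kernel"
          "bert.encoder.layer.0.intermediate.dense.bias",        "bert/encoder/layer_0/intermediate/dense/bias" ]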
GeorgeS2019 commented 2 years ago

48 files

15,627,392 bert.embeddings.word_embeddings.weight.npy
   262,272 bert.embeddings.position_embeddings.weight.npy
     1,152 bert.embeddings.token_type_embeddings.weight.npy
       640 bert.embeddings.LayerNorm.weight.npy
       640 bert.embeddings.LayerNorm.bias.npy
    65,664 bert.encoder.layer.0.attention.self.query.weight.npy
       640 bert.encoder.layer.0.attention.self.query.bias.npy
    65,664 bert.encoder.layer.0.attention.self.key.weight.npy
       640 bert.encoder.layer.0.attention.self.key.bias.npy
    65,664 bert.encoder.layer.0.attention.self.value.weight.npy
       640 bert.encoder.layer.0.attention.self.value.bias.npy
    65,664 bert.encoder.layer.0.attention.output.dense.weight.npy
       640 bert.encoder.layer.0.attention.output.dense.bias.npy
       640 bert.encoder.layer.0.attention.output.LayerNorm.weight.npy
       640 bert.encoder.layer.0.attention.output.LayerNorm.bias.npy
   262,272 bert.encoder.layer.0.intermediate.dense.weight.npy
     2,176 bert.encoder.layer.0.intermediate.dense.bias.npy
   262,272 bert.encoder.layer.0.output.dense.weight.npy
       640 bert.encoder.layer.0.output.dense.bias.npy
       640 bert.encoder.layer.0.output.LayerNorm.weight.npy
       640 bert.encoder.layer.0.output.LayerNorm.bias.npy
    65,664 bert.encoder.layer.1.attention.self.query.weight.npy
       640 bert.encoder.layer.1.attention.self.query.bias.npy
    65,664 bert.encoder.layer.1.attention.self.key.weight.npy
       640 bert.encoder.layer.1.attention.self.key.bias.npy
    65,664 bert.encoder.layer.1.attention.self.value.weight.npy
       640 bert.encoder.layer.1.attention.self.value.bias.npy
    65,664 bert.encoder.layer.1.attention.output.dense.weight.npy
       640 bert.encoder.layer.1.attention.output.dense.bias.npy
       640 bert.encoder.layer.1.attention.output.LayerNorm.weight.npy
       640 bert.encoder.layer.1.attention.output.LayerNorm.bias.npy
   262,272 bert.encoder.layer.1.intermediate.dense.weight.npy
     2,176 bert.encoder.layer.1.intermediate.dense.bias.npy
   262,272 bert.encoder.layer.1.output.dense.weight.npy
       640 bert.encoder.layer.1.output.dense.bias.npy
       640 bert.encoder.layer.1.output.LayerNorm.weight.npy
       640 bert.encoder.layer.1.output.LayerNorm.bias.npy
    65,664 bert.pooler.dense.weight.npy
       640 bert.pooler.dense.bias.npy
   122,216 cls.predictions.bias.npy
    65,664 cls.predictions.transform.dense.weight.npy
       640 cls.predictions.transform.dense.bias.npy
       640 cls.predictions.transform.LayerNorm.weight.npy
       640 cls.predictions.transform.LayerNorm.bias.npy
15,627,392 cls.predictions.decoder.weight.npy
   122,216 cls.predictions.decoder.bias.npy
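
A quick sanity check on these sizes (arithmetic added here, not part of the original listing): each file size equals the tensor's element count times 4 bytes (float32) plus a 128-byte .npy header, e.g. 30522 x 128 x 4 + 128 = 15,627,392.

    // Expected .npy file size for a float32 tensor: data bytes + 128-byte .npy header.
    let expectedNpyBytes (shape: int64 list) =
        (shape |> List.fold (fun acc d -> acc * d) 1L) * 4L + 128L

    expectedNpyBytes [ 30522L; 128L ]   // 15627392 -> bert.embeddings.word_embeddings.weight.npy
    expectedNpyBytes [ 128L ]           // 640      -> bert.embeddings.LayerNorm.weight.npy
    expectedNpyBytes [ 512L ]           // 2176     -> bert.encoder.layer.0.intermediate.dense.bias.npy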
fwaris commented 2 years ago

@GeorgeS2019 thanks for the update - good to know

Note that IF your model uses the PyTorch/TorchSharp transformer layers, e.g.

    let encoderLayer = torch.nn.TransformerEncoderLayer(HIDDEN, N_HEADS, MAX_POS_EMB, ATTN_DROPOUT_PROB, activation=ENCODER_ACTIVATION)
    let encoder = torch.nn.TransformerEncoder(encoderLayer, ENCODER_LAYERS)

THEN several of the BERT layers have to be packaged together to load the encoder weights correctly, e.g.

type PostProc = V | H | T | N

let postProc (ts:torch.Tensor list) = function
    | V -> torch.vstack(ResizeArray ts)
    | H -> torch.hstack(ResizeArray ts)
    | T -> ts.Head.T                  //Linear layer weights need to be transposed. See https://github.com/pytorch/pytorch/issues/2159
    | N -> ts.Head

let nameMap =
    [
        "encoder.layers.#.self_attn.in_proj_weight",["encoder/layer_#/attention/self/query/kernel"; 
                                                     "encoder/layer_#/attention/self/key/kernel";    
                                                     "encoder/layer_#/attention/self/value/kernel"],        V

        "encoder.layers.#.self_attn.in_proj_bias",  ["encoder/layer_#/attention/self/query/bias";
                                                     "encoder/layer_#/attention/self/key/bias"; 
                                                     "encoder/layer_#/attention/self/value/bias"],          H
...

In the nameMap list of 3-tuples above, the 1st element is the TorchSharp layer name and the 2nd element is the list of BERT layers whose weights should be concatenated together to match the PyTorch layer's weights. The PostProc value (3rd element) says how they are combined: V and H stand for vertical and horizontal stacking, respectively.

If the weights come from TensorFlow, then 'linear' layer weights have to be transposed first (PostProc.T); see the full notebook for details. If the weights come from a PyTorch model, this transposition may not be needed.
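
To make the mapping concrete, here is a loading sketch under stated assumptions (it is not the notebook's actual code): loadBertTensor is a hypothetical function that returns the pre-trained tensor for a given BERT weight name, and the nameMap keys are assumed to match the model's state_dict keys.

    let loadEncoderWeights (model: torch.nn.Module) (nLayers: int) =
        use _ = torch.no_grad()
        let stateDict = model.state_dict()                    // TorchSharp name -> tensor
        for layer in 0 .. nLayers - 1 do
            for (tsName, bertNames, proc) in nameMap do
                let dest = stateDict.[tsName.Replace("#", string layer)]
                let srcs = bertNames |> List.map (fun n -> loadBertTensor (n.Replace("#", string layer)))
                dest.copy_(postProc srcs proc) |> ignore      // stack/transpose, then copy in place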

I believe the Hugging Face versions don't use the PyTorch 'transformer' layer; they use base layers that, taken together, are equivalent to a transformer layer.
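
For illustration, a sketch of that decomposition (assumed shapes from bert_uncased_L-2_H-128_A-2, not Hugging Face's actual classes): keeping query/key/value as separate Linear modules is why the Hugging Face checkpoint has separate parameters such as bert.encoder.layer.0.attention.self.query.weight, while the PyTorch/TorchSharp TransformerEncoderLayer folds them into a single self_attn.in_proj_weight, which is what the vstack entry in nameMap reconstructs.

    let hidden = 128L
    // Hugging Face-style: three separate projections -> parameters named query/key/value
    let query = torch.nn.Linear(hidden, hidden)    // weight [128, 128], bias [128]
    let key   = torch.nn.Linear(hidden, hidden)
    let value = torch.nn.Linear(hidden, hidden)
    // nn.Transformer-style: one packed projection; self_attn.in_proj_weight has shape
    // [3 * 128, 128], i.e. the three weights above stacked vertically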