idiap / fast-transformers

PyTorch library for fast transformer implementations

layernorm eps is not copied properly when cloning HF_Bert #103

Open huu4ontocord opened 3 years ago

huu4ontocord commented 3 years ago

From the following code (adapted from test_weight_mapper.py):

import torch
import torch.nn as nn

from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.weight_mapper import PytorchMapper, \
    HugginfaceBertEncoderMapper, LongformerMapper
from transformers import BertConfig, BertModel

def load_fast_bert():
    # Build a Huggingface BERT and a fast-transformers encoder with matching
    # dimensions, then copy the BERT encoder weights across with the mapper.
    bert = BertModel(BertConfig())
    encoder = TransformerEncoderBuilder.from_kwargs(
        n_layers=12,
        n_heads=12,
        query_dimensions=64,
        value_dimensions=64,
        feed_forward_dimensions=3072,
        attention_type="full",
        final_normalization=False,
        activation="gelu"
    ).get()

    encoder.load_state_dict(
        HugginfaceBertEncoderMapper().map(bert.encoder.state_dict())
    )
    return encoder, bert
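
The two module dumps below can be reproduced simply by instantiating and printing both models, e.g.:

encoder, bert = load_fast_bert()
print(encoder)  # fast-transformers encoder: LayerNorm eps is the nn.LayerNorm default, 1e-05
print(bert)     # Huggingface BertModel: LayerNorm eps is 1e-12 from BertConfig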

The resulting encoder looks like this (truncated to the first layer):

TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attention): AttentionLayer(
        (inner_attention): FullAttention(
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (query_projection): Linear(in_features=768, out_features=768, bias=True)
        (key_projection): Linear(in_features=768, out_features=768, bias=True)
        (value_projection): Linear(in_features=768, out_features=768, bias=True)
        (out_projection): Linear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=3072, bias=True)
      (linear2): Linear(in_features=3072, out_features=768, bias=True)
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )

The Huggingface bert model looks like this (also truncated to the first layer):

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
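
Note the mismatch: norm1/norm2 in the cloned encoder use the nn.LayerNorm default eps=1e-05, while the Huggingface weights were trained with eps=1e-12. One possible stop-gap is to patch the eps by hand after loading. A minimal sketch (fix_layernorm_eps is a hypothetical helper; norm1/norm2 are the attribute names from the repr above):

def fix_layernorm_eps(encoder, eps=1e-12):
    # Overwrite eps on every LayerNorm in the cloned encoder so it matches
    # the value the Huggingface BERT weights were trained with.
    for layer in encoder.layers:
        layer.norm1.eps = eps
        layer.norm2.eps = eps
    return encoder

encoder, bert = load_fast_bert()
encoder = fix_layernorm_eps(encoder, eps=bert.config.layer_norm_eps)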
huu4ontocord commented 3 years ago

Do you need to pass the eps into TransformerEncoderLayer and TransformerDecoderLayer, expose it as a constructor argument, and thread it through the builder and the weight mapper as well?
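
For illustration only (this is not the library's current TransformerEncoderLayer signature), the kind of change I mean would look roughly like this, with the eps forwarded to nn.LayerNorm and exposed through the builder kwargs:

import torch.nn as nn

class TransformerEncoderLayerWithEps(nn.Module):
    # Hypothetical sketch: only the constructor is shown; the forward pass
    # would stay as it is in fast-transformers.
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1,
                 layer_norm_eps=1e-05):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        # The eps would be threaded through from the builder / mapper.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)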

huu4ontocord commented 3 years ago

On a related note, it would be cool to be able to specify other norms, such as ScaleNorm, and to configure the activation function to something besides the built-in PyTorch activations (e.g. relu^2).
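
For concreteness, a rough sketch of the two things I mean (my own code, not anything in fast-transformers): ScaleNorm as in Nguyen & Salazar's "Transformers without Tears", and a squared-ReLU activation:

import torch
import torch.nn as nn

class ScaleNorm(nn.Module):
    # A single learned scale g divided by the L2 norm of the activations.
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.tensor(d_model ** 0.5))
        self.eps = eps

    def forward(self, x):
        return self.g * x / (x.norm(dim=-1, keepdim=True) + self.eps)

def relu_squared(x):
    # relu^2 activation, i.e. relu(x) ** 2.
    return torch.relu(x) ** 2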