bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation

What should I do if I want to use lightseq to accelerate MarianMTModel #422

Open Youggls opened 1 year ago

Youggls commented 1 year ago

Hi, I'm trying to use lightseq to accelerate MarianMTModel seq2seq inference. This model is basically the same as BartForConditionalGeneration.

There are two differences between MarianMTModel and BartForConditionalGeneration:

  1. MarianMTModel uses the swish activation function instead of gelu.
  2. MarianMTModel has no layernorm_embedding module in its encoder or decoder.

So I edited hf_bart_export.py to try to export MarianMTModel. I commented out the call to fill_pb_layer that fills src_embedding, and changed trg_emb_mapping_dict from

trg_emb_mapping_dict = OrderedDict(
    {
        "norm_scale": "layernorm_embedding weight",
        "norm_bias": "layernorm_embedding bias",
        "shared_bias": "final_logits_bias",
    }
)

to

trg_emb_mapping_dict = OrderedDict(
    {
        "shared_bias": "final_logits_bias",
    }
)
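
The same edit can be written as a filter rather than done by hand. This is only a sketch: `drop_missing_patterns` is a hypothetical helper, not part of LightSeq, and it assumes an entry should be dropped whenever its pattern names a module the model lacks. It reproduces the change above and does not by itself make the exported model loadable.

```python
from collections import OrderedDict

# Target-embedding mapping as it appears in hf_bart_export.py (quoted in this issue).
bart_trg_emb_mapping_dict = OrderedDict(
    {
        "norm_scale": "layernorm_embedding weight",
        "norm_bias": "layernorm_embedding bias",
        "shared_bias": "final_logits_bias",
    }
)


def drop_missing_patterns(mapping, missing_module):
    """Hypothetical helper: keep entries whose pattern does not name missing_module.

    Only handles simple space-separated patterns (no "&&" parts).
    """
    return OrderedDict(
        (key, pattern)
        for key, pattern in mapping.items()
        if missing_module not in pattern.split()
    )


marian_trg_emb_mapping_dict = drop_missing_patterns(
    bart_trg_emb_mapping_dict, "layernorm_embedding"
)
print(marian_trg_emb_mapping_dict)
```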

I then ran hf_bart_export.py and got the hdf5 output without any error. However, loading the model with lightseq.inference.Transformer failed with the following error:

Parsing hdf5: ./opus-mt-en-fr.hdf5
loading 116 MB of embedding weight.
Finish loading src_emb_wei from host to device
loading 128 MB of embedding weight.
Finish loading trg_emb_wei from host to device
loading 72 MB of encoder weight.
Traceback (most recent call last):
  File "test.py", line 3, in <module>
    model = lsi.Transformer('./opus-mt-en-fr.hdf5', 128)
RuntimeError: Invalid shape 0. Wrong multihead_norm_scale_size !

What should I do if I want to use the MarianMTModel under lightseq?

Thanks!

Youggls commented 1 year ago

Here is my own export script; it is basically the same as hf_bart_export.py.

"""
Export Hugging Face BART models to protobuf/hdf5 format.
"""
import os
from collections import OrderedDict
import tensorflow as tf
import h5py
import numpy as np
from operator import attrgetter
from lightseq.training.ops.pytorch.export import gather_token_embedding, fill_pb_layer
from export.proto.transformer_pb2 import Transformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from export.util import parse_args
# import pdb;pdb.set_trace()
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

"""
For the mapping dictionary: each key is the name of a proto parameter, and each value is an expression in which && separates tensor-name matching paths and expressions.

The sub-patterns of a path are separated by spaces, and an expression starts with expression_. Each tensor can be operated on separately, and multiple expressions are supported. The results of multiple matching paths
and expressions are finally concatenated on axis = -1.
"""
enc_layer_mapping_dict = OrderedDict(
    {
        "multihead_norm_scale": "self_attn_layer_norm weight",
        "multihead_norm_bias": "self_attn_layer_norm bias",
        "multihead_project_kernel_qkv": "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_qkv": "self_attn q_proj bias&&self_attn k_proj bias&&self_attn v_proj bias",
        "multihead_project_kernel_output": "self_attn out_proj weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_output": "self_attn out_proj bias",
        "ffn_norm_scale": "final_layer_norm weight",
        "ffn_norm_bias": "final_layer_norm bias",
        "ffn_first_kernel": "fc1 weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "fc1 bias",
        "ffn_second_kernel": "fc2 weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "fc2 bias",
    }
)

dec_layer_mapping_dict = OrderedDict(
    {
        "self_norm_scale": "self_attn_layer_norm weight",
        "self_norm_bias": "self_attn_layer_norm bias",
        "self_project_kernel_qkv": "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)",
        "self_project_bias_qkv": "self_attn q_proj bias&&self_attn k_proj bias&&self_attn v_proj bias",
        "self_project_kernel_output": "self_attn out_proj weight&&expression_.transpose(0, 1)",
        "self_project_bias_output": "self_attn out_proj bias",
        "encdec_norm_scale": "encoder_attn_layer_norm weight",
        "encdec_norm_bias": "encoder_attn_layer_norm bias",
        "encdec_project_kernel_q": "encoder_attn q_proj weight&&expression_.transpose(0, 1)",
        "encdec_project_bias_q": "encoder_attn q_proj bias",
        "encdec_project_kernel_output": "encoder_attn out_proj weight&&expression_.transpose(0, 1)",
        "encdec_project_bias_output": "encoder_attn out_proj bias",
        "ffn_norm_scale": "final_layer_norm weight",
        "ffn_norm_bias": "final_layer_norm bias",
        "ffn_first_kernel": "fc1 weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "fc1 bias",
        "ffn_second_kernel": "fc2 weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "fc2 bias",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "norm_scale": "layernorm_embedding weight",
        "norm_bias": "layernorm_embedding bias",
    }
)

trg_emb_mapping_dict = OrderedDict(
    {
        # "norm_scale": "layernorm_embedding weight",
        # "norm_bias": "layernorm_embedding bias",
        "shared_bias": "final_logits_bias",
    }
)

def _get_encode_output_mapping_dict(dec_layer_num):
    encode_output_kernel_pattern = [
        "encoder_attn {0} k_proj weight&&encoder_attn {0} v_proj weight".format(ele)
        for ele in range(dec_layer_num)
    ]
    encode_output_bias_pattern = [
        "encoder_attn {0} k_proj bias&&encoder_attn {0} v_proj bias".format(ele)
        for ele in range(dec_layer_num)
    ]

    return {
        "encode_output_project_kernel_kv": "&&".join(
            encode_output_kernel_pattern + ["expression_.transpose(0, 1)"]
        ),
        "encode_output_project_bias_kv": "&&".join(encode_output_bias_pattern),
    }

def save_bart_proto_to_hdf5(transformer: Transformer, f: h5py.File):
    """Convert bart protobuf to hdf5 format to support larger weight."""
    MODEL_CONF_KEYS = [
        # model_conf
        "head_num",
        "beam_size",
        "extra_decode_length",
        "length_penalty",
        "src_padding_id",
        "trg_start_id",
        "diverse_lambda",
        "sampling_method",
        "topp",
        "topk",
        "trg_end_id",
        "is_post_ln",
        "no_scale_embedding",
        "use_gelu",
        "is_multilingual",
    ]

    EMBEDDING_KEYS = [
        # src_embedding
        # trg_embedding
        "token_embedding",
        "position_embedding",
        "norm_scale",
        "norm_bias",
        "encode_output_project_kernel_kv",
        "encode_output_project_bias_kv",
        "shared_bias",
        "lang_emb",
        "trg_vocab_mask",
    ]

    ENCODER_LAYER_KEYS = [
        # encoder_stack/{i}
        "multihead_norm_scale",
        "multihead_norm_bias",
        "multihead_project_kernel_qkv",
        "multihead_project_bias_qkv",
        "multihead_project_kernel_output",
        "multihead_project_bias_output",
        "ffn_norm_scale",
        "ffn_norm_bias",
        "ffn_first_kernel",
        "ffn_first_bias",
        "ffn_second_kernel",
        "ffn_second_bias",
    ]

    DECODER_LAYER_KEYS = [
        # decoder_stack/{i}
        "self_norm_scale",
        "self_norm_bias",
        "self_project_kernel_qkv",
        "self_project_bias_qkv",
        "self_project_kernel_output",
        "self_project_bias_output",
        "encdec_norm_scale",
        "encdec_norm_bias",
        "encdec_project_kernel_q",
        "encdec_project_bias_q",
        "encdec_project_kernel_output",
        "encdec_project_bias_output",
        "ffn_norm_scale",
        "ffn_norm_bias",
        "ffn_first_kernel",
        "ffn_first_bias",
        "ffn_second_kernel",
        "ffn_second_bias",
    ]
    base_attr_to_keys = {
        "src_embedding": EMBEDDING_KEYS,
        "trg_embedding": EMBEDDING_KEYS,
        "model_conf": MODEL_CONF_KEYS,
    }

    print(f"start converting protobuf to hdf5 format.")
    # load src_embedding, trg_embedding, model_conf
    for base_attr, keys in base_attr_to_keys.items():
        for key in keys:
            hdf5_key = f"{base_attr}/{key}"
            proto_attr = f"{base_attr}.{key}"

            if key not in dir(attrgetter(base_attr)(transformer)):
                print(f"key {key} not found in {base_attr}, skipping")
                continue

            print(f"loading transformer {proto_attr} -> {hdf5_key}")
            _data = attrgetter(proto_attr)(transformer)
            if isinstance(_data, str):
                print(
                    "find type str, explicitly convert string to ascii encoded array."
                )
                # explicitly convert to an array of char (int8) to avoid issues when reading strings in C
                _data = np.array([ord(c) for c in _data]).astype(np.int8)
            f.create_dataset(hdf5_key, data=_data)

    # save number of layers metadata
    f.create_dataset("model_conf/n_encoder_stack", data=len(transformer.encoder_stack))
    f.create_dataset("model_conf/n_decoder_stack", data=len(transformer.decoder_stack))

    # load encoder_stack
    for layer_id, layer in enumerate(transformer.encoder_stack):
        for key in ENCODER_LAYER_KEYS:
            hdf5_key = f"encoder_stack/{layer_id}/{key}"
            proto_attr = key
            print(f"loading transformer.encoder_stack {proto_attr} -> {hdf5_key}")
            f.create_dataset(hdf5_key, data=attrgetter(proto_attr)(layer))

    # load decoder_stack
    for layer_id, layer in enumerate(transformer.decoder_stack):
        for key in DECODER_LAYER_KEYS:
            hdf5_key = f"decoder_stack/{layer_id}/{key}"
            proto_attr = key
            print(f"loading transformer.decoder_stack {proto_attr} -> {hdf5_key}")
            f.create_dataset(hdf5_key, data=attrgetter(proto_attr)(layer))

    print(f"proto to hdf5 conversion completed.")

def extract_transformer_weights(
    output_file,
    model_dir,
    generation_method,
    max_step,
    extra_decode_length=50,
    beam_size=4,
    length_penalty: float = 0,
    topk=1,
    topp=0.75,
    lang="en",
    only_decoder=True,
    save_proto=False,
):
    transformer = Transformer()
    # load var names
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    assert model.config.encoder_attention_heads == model.config.decoder_attention_heads
    head_num = model.config.encoder_attention_heads
    reloaded = model.state_dict()

    encoder_state_dict = {}
    decoder_state_dict = {}
    for k in reloaded:
        if k.startswith("model.encoder."):
            encoder_state_dict[k] = reloaded[k]
        if k.startswith("model.decoder."):
            decoder_state_dict[k] = reloaded[k]
        if k == "model.shared.weight":
            encoder_state_dict[k] = reloaded[k]
            decoder_state_dict[k] = reloaded[k]
        if k == "final_logits_bias":
            decoder_state_dict[k] = reloaded[k]

    dec_var_name_list = list(decoder_state_dict.keys())
    enc_var_name_list = list(encoder_state_dict.keys())

    # fill each encoder layer's params
    if not only_decoder:
        enc_tensor_names = {}
        for name in enc_var_name_list:
            name_split = name.split(".")
            if len(name_split) <= 3 or not name_split[3].isdigit():
                continue
            layer_id = int(name_split[3])
            enc_tensor_names.setdefault(layer_id, []).append(name)

        for layer_id in sorted(enc_tensor_names.keys()):
            fill_pb_layer(
                enc_tensor_names[layer_id],
                encoder_state_dict,
                transformer.encoder_stack.add(),
                enc_layer_mapping_dict,
            )

    # fill each decoder layer's params
    dec_tensor_names = {}
    for name in dec_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 3 or not name_split[3].isdigit():
            continue
        layer_id = int(name_split[3])
        dec_tensor_names.setdefault(layer_id, []).append(name)

    for layer_id in sorted(dec_tensor_names.keys()):
        fill_pb_layer(
            dec_tensor_names[layer_id],
            decoder_state_dict,
            transformer.decoder_stack.add(),
            dec_layer_mapping_dict,
        )

    # fill src_embedding
    if not only_decoder:
        # import pdb;pdb.set_trace()
        # MarianMTModel doesn't have encoder layernorm
        # fill_pb_layer(
        #     enc_var_name_list,
        #     encoder_state_dict,
        #     transformer.src_embedding,
        #     src_emb_mapping_dict,
        # )
        # bart position index starts from 2
        # https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/configuration_bart.py#L208
        # https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/modeling_bart.py#L821
        pos_emb_list = (
            encoder_state_dict["model.encoder.embed_positions.weight"]
            .numpy()[
                2 : 2 + max_step, :
            ]  # because in huggingface bart, the position embedding starts from 2
            .reshape([-1])
            .tolist()
        )
        transformer.src_embedding.position_embedding[:] = pos_emb_list
        print(
            "model.encoder.embed_positions.weight -> src_embedding.position_embedding, shape: {}, conversion finished!".format(
                encoder_state_dict["model.encoder.embed_positions.weight"]
                .numpy()[2 : 2 + max_step, :]
                .shape
            )
        )
        src_tb, _ = gather_token_embedding(
            enc_var_name_list, encoder_state_dict, "shared", scale=False
        )
        transformer.src_embedding.token_embedding[:] = src_tb.flatten().tolist()

    # fill trg_embedding
    encode_output_mapping_dict = _get_encode_output_mapping_dict(len(dec_tensor_names))
    trg_emb_mapping_dict.update(encode_output_mapping_dict)
    fill_pb_layer(
        dec_var_name_list,
        decoder_state_dict,
        transformer.trg_embedding,
        trg_emb_mapping_dict,
    )
    pos_emb_list = (
        decoder_state_dict["model.decoder.embed_positions.weight"]
        .numpy()[2 : 2 + max_step, :]
        .reshape([-1])
        .tolist()
    )
    transformer.trg_embedding.position_embedding[:] = pos_emb_list
    print(
        "model.decoder.embed_positions.weight -> trg_embedding.position_embedding, shape: {}, conversion finished!".format(
            decoder_state_dict["model.decoder.embed_positions.weight"]
            .numpy()[2 : 2 + max_step, :]
            .shape
        )
    )
    # assert lang in LANG2ID
    trg_tb, _ = gather_token_embedding(
        dec_var_name_list, decoder_state_dict, "shared", scale=False
    )
    transformer.trg_embedding.token_embedding[:] = trg_tb.transpose().flatten().tolist()
    print(
        "token_embedding.weight -> trg_embedding.token_embedding, shape: {}, conversion finished!".format(
            trg_tb.transpose().shape
        )
    )

    # change encoder layer norm scale&bias position
    tmp_scale, tmp_bias = (
        transformer.src_embedding.norm_scale,
        transformer.src_embedding.norm_bias,
    )
    for i, encoder in enumerate(transformer.encoder_stack):
        print("***Fix encoder layer {} LayerNorm scale and bias***".format(i))
        new_tmp_scale, new_tmp_bias = (
            encoder.multihead_norm_scale[:],
            encoder.multihead_norm_bias[:],
        )
        encoder.multihead_norm_scale[:], encoder.multihead_norm_bias[:] = (
            tmp_scale,
            tmp_bias,
        )
        print(
            "multihead_norm_scale: {} -> {}\nmultihead_norm_bias: {} -> {}".format(
                new_tmp_scale[:3],
                encoder.multihead_norm_scale[:3],
                new_tmp_bias[:3],
                encoder.multihead_norm_bias[:3],
            )
        )
        tmp_scale, tmp_bias = new_tmp_scale[:], new_tmp_bias[:]

        new_tmp_scale, new_tmp_bias = (
            encoder.ffn_norm_scale[:],
            encoder.ffn_norm_bias[:],
        )
        encoder.ffn_norm_scale[:], encoder.ffn_norm_bias[:] = (
            tmp_scale,
            tmp_bias,
        )
        print(
            "ffn_norm_scale: {} -> {}\nffn_norm_bias: {} -> {}".format(
                new_tmp_scale[:3],
                encoder.ffn_norm_scale[:3],
                new_tmp_bias[:3],
                encoder.ffn_norm_bias[:3],
            )
        )
        tmp_scale, tmp_bias = new_tmp_scale[:], new_tmp_bias[:]
    transformer.src_embedding.norm_scale[:], transformer.src_embedding.norm_bias[:] = (
        tmp_scale,
        tmp_bias,
    )

    # change decoder layer norm scale&bias position
    tmp_scale, tmp_bias = (
        transformer.trg_embedding.norm_scale,
        transformer.trg_embedding.norm_bias,
    )
    for i, decoder in enumerate(transformer.decoder_stack):
        print("***Fix decoder layer {} LayerNorm scale and bias***".format(i))
        new_tmp_scale, new_tmp_bias = (
            decoder.self_norm_scale[:],
            decoder.self_norm_bias[:],
        )
        decoder.self_norm_scale[:], decoder.self_norm_bias[:] = tmp_scale, tmp_bias
        print(
            "self_norm_scale: {} -> {}\nself_norm_bias: {} -> {}".format(
                new_tmp_scale[:3],
                decoder.self_norm_scale[:3],
                new_tmp_bias[:3],
                decoder.self_norm_bias[:3],
            )
        )
        tmp_scale, tmp_bias = new_tmp_scale[:], new_tmp_bias[:]

        new_tmp_scale, new_tmp_bias = (
            decoder.encdec_norm_scale[:],
            decoder.encdec_norm_bias[:],
        )
        decoder.encdec_norm_scale[:], decoder.encdec_norm_bias[:] = tmp_scale, tmp_bias
        print(
            "encdec_norm_scale: {} -> {}\nencdec_norm_bias: {} -> {}".format(
                new_tmp_scale[:3],
                decoder.encdec_norm_scale[:3],
                new_tmp_bias[:3],
                decoder.encdec_norm_bias[:3],
            )
        )
        tmp_scale, tmp_bias = new_tmp_scale[:], new_tmp_bias[:]

        new_tmp_scale, new_tmp_bias = (
            decoder.ffn_norm_scale[:],
            decoder.ffn_norm_bias[:],
        )
        decoder.ffn_norm_scale[:], decoder.ffn_norm_bias[:] = (
            tmp_scale,
            tmp_bias,
        )
        print(
            "ffn_norm_scale: {} -> {}\nffn_norm_bias: {} -> {}".format(
                new_tmp_scale[:3],
                decoder.ffn_norm_scale[:3],
                new_tmp_bias[:3],
                decoder.ffn_norm_bias[:3],
            )
        )
        tmp_scale, tmp_bias = new_tmp_scale[:], new_tmp_bias[:]
    transformer.trg_embedding.norm_scale[:], transformer.trg_embedding.norm_bias[:] = (
        tmp_scale,
        tmp_bias,
    )

    # fill in conf

    transformer.model_conf.head_num = head_num

    transformer.model_conf.beam_size = beam_size
    transformer.model_conf.length_penalty = length_penalty

    transformer.model_conf.extra_decode_length = extra_decode_length
    transformer.model_conf.src_padding_id = tokenizer.pad_token_id
    transformer.model_conf.trg_start_id = tokenizer.pad_token_id
    transformer.model_conf.trg_end_id = tokenizer.eos_token_id

    transformer.model_conf.sampling_method = generation_method
    transformer.model_conf.topk = topk
    transformer.model_conf.topp = topp
    transformer.model_conf.diverse_lambda = 0
    transformer.model_conf.is_post_ln = True
    transformer.model_conf.no_scale_embedding = not model.config.scale_embedding
    transformer.model_conf.use_gelu = True

    if save_proto:
        output_file += ".pb"
        print("Saving model to protobuf...")
        print("Writing to {0}".format(output_file))
        with tf.io.gfile.GFile(output_file, "wb") as fout:
            fout.write(transformer.SerializeToString())

        transformer = Transformer()
        with tf.io.gfile.GFile(output_file, "rb") as fin:
            transformer.ParseFromString(fin.read())
        print(transformer.model_conf)
    else:
        output_file += ".hdf5"
        print("Saving model to hdf5...")
        print("Writing to {0}".format(output_file))
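        # write the weights, closing the file even if the conversion fails
        with h5py.File(output_file, "w") as f:
            save_bart_proto_to_hdf5(transformer, f)

        def _print_pair(key, value):
            # sampling_method is stored as an int8 array of ASCII codes
            if key == "sampling_method":
                value = "".join(map(chr, value[()]))
            else:
                value = value[()]
            print(f"{key}: {value}")

        # re-open the file read-only and echo the saved model_conf
        with h5py.File(output_file, "r") as f:
            for key, value in f["model_conf"].items():
                _print_pair(key, value)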
    return transformer

if __name__ == "__main__":
    args = parse_args()
    if args.generation_method not in ["beam_search", "topk", "topp", "topk_greedy"]:
        args.generation_method = "beam_search"
    # if save_proto is True, extension .pb will be added, otherwise .hdf5 is added
    output_lightseq_model_name = "opus-mt-en-fr"  # you can rename it to "lightseq_bart_large" for large model
    input_huggingface_bart_model = (
        'Helsinki-NLP/opus-mt-en-fr' # Example: you can try "facebook/bart-large" as well
    )
    beam_size = 4
    max_step = 50  # max step for generation, it decides GPU memory occupancy
    # maximum_generation_length = min(src_length + extra_decode_length, max_step)
    extra_decode_length = 50
    length_penalty = 1.0
    res = extract_transformer_weights(
        output_lightseq_model_name,
        input_huggingface_bart_model,
        generation_method=args.generation_method,
        beam_size=beam_size,
        max_step=max_step,
        extra_decode_length=extra_decode_length,
        only_decoder=False,
        length_penalty=length_penalty,
        save_proto=False,
    )
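
The mapping-value format described in the script's docstring can be illustrated with a small parser. This is a sketch of how such a string breaks down, not LightSeq's `fill_pb_layer`; `parse_mapping_value` is a hypothetical name introduced here.

```python
def parse_mapping_value(value):
    """Split a mapping value into matching paths and expressions.

    Illustrative only: this is not LightSeq's fill_pb_layer.
    """
    paths, expressions = [], []
    for part in value.split("&&"):
        if part.startswith("expression_"):
            # Keep what follows the "expression_" prefix, e.g. ".transpose(0, 1)".
            expressions.append(part[len("expression_"):])
        else:
            # A matching path is a list of space-separated sub-patterns.
            paths.append(part.split())
    return paths, expressions


paths, expressions = parse_mapping_value(
    "self_attn q_proj weight&&self_attn k_proj weight"
    "&&self_attn v_proj weight&&expression_.transpose(0, 1)"
)
print(paths)
print(expressions)
```

For the multihead_project_kernel_qkv entry used above, this yields three matching paths (q, k and v projections) and a single transpose expression.
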

Youggls commented 1 year ago

I guess this error occurs in transformer_weight.cc, because _hidden_size is initialized by _hidden_size = transformer.trg_embedding().norm_scale_size(), but MarianMTModel doesn't have this module. Is there any way to solve this problem?

zjersey commented 1 year ago

@Youggls In LightSeq, the layernorm_embedding weight is stored as self_attn_layer_norm in the first encoder (decoder) layer, while the norm_scale and norm_bias stored in EmbeddingLayer represent the "en(de)coder output layernorm" for pre-norm, or "the last ffn layer norm" for post-norm.
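
This layout matches the "change encoder layer norm scale&bias position" loop in the export script above, which shifts every layer norm one slot later. The sketch below replays that shift on toy lists instead of protobuf fields; `rotate_post_ln_norms` is a hypothetical name. It suggests, as an assumption rather than a confirmed diagnosis, why skipping the src_embedding fill would leave the first layer's multihead_norm_scale empty.

```python
def rotate_post_ln_norms(embedding_norm, layer_norms):
    """Shift every layer norm one slot later, as the export script's loop does.

    layer_norms is a flat list in execution order, e.g.
    [layer0.self_attn_norm, layer0.ffn_norm, layer1.self_attn_norm, ...].
    Returns (new_embedding_norm, new_layer_norms).
    """
    carried = embedding_norm
    shifted = []
    for norm in layer_norms:
        shifted.append(carried)  # this slot receives the previous norm
        carried = norm           # and hands its own norm to the next slot
    return carried, shifted


# BART-like case: the embedding layer norm exists, so the first slot is filled.
bart_emb, bart_layers = rotate_post_ln_norms([1.0, 1.0], [[2.0, 2.0], [3.0, 3.0]])
print(bart_emb, bart_layers)

# Marian-like case: no layernorm_embedding was exported, so the value is empty
# and the first layer's norm slot ends up with size 0.
marian_emb, marian_layers = rotate_post_ln_norms([], [[2.0, 2.0], [3.0, 3.0]])
print(len(marian_layers[0]))
```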

Youggls commented 1 year ago

@zjersey Sorry, I may have made a mistake earlier. In huggingface, the encoder and decoder of BartForConditionalGeneration each have a module, called encoder.layernorm_embedding and decoder.layernorm_embedding. As huggingface's source code shows, these two modules normalize the output of the embedding layer. They should correspond to norm_scale and norm_bias in lightseq. MarianMTModel's state_dict is basically the same as BartForConditionalGeneration's; the difference is that MarianMTModel lacks these two modules, so the output of the embedding layer is fed directly to EncoderLayer (source code).

Given this, what should I do if I want to use MarianMTModel?

Youggls commented 1 year ago

By the way, the following code shows the differences between the state_dict keys of MarianMTModel and BartForConditionalGeneration:

from transformers import AutoModelForSeq2SeqLM

marian = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-fr')
bart = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base')
# build each state_dict once instead of on every loop iteration
marian_keys = marian.state_dict().keys()
bart_keys = bart.state_dict().keys()
print('*** bart special module ***')
for key in bart_keys:
    if key not in marian_keys:
        print(key)
print('*** marian special module ***')
for key in marian_keys:
    if key not in bart_keys:
        print(key)

zjersey commented 1 year ago

@Youggls For now this model is not supported.

For BART, layernorm_embedding operates in the self-attention of the first layer, and the kernel ker_norm_layer_resual_launcher does not currently accept nullptr as its layernorm pointers.

We will consider supporting this model in the future.
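
Because the kernel needs real layer norm weights, one might be tempted to export ones for the scale and zeros for the bias. That would probably not reproduce MarianMTModel, since layer norm with unit scale and zero bias still normalizes its input. The sketch below is plain-Python textbook layer norm, not LightSeq's kernel, and the `eps` value is an assumption.

```python
import math


def layer_norm(x, scale, bias, eps=1e-5):
    """Standard layer normalization over a single vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [
        s * (v - mean) / math.sqrt(var + eps) + b
        for v, s, b in zip(x, scale, bias)
    ]


x = [1.0, 2.0, 4.0]
# Unit scale and zero bias: the "neutral" weights one might try to export.
y = layer_norm(x, scale=[1.0] * len(x), bias=[0.0] * len(x))
print(x)
print(y)  # zero-mean and rescaled, so not equal to x
```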

Youggls commented 1 year ago

@zjersey Thanks for your reply. Lightseq is a really good library. Helsinki-NLP provides many pretrained MarianMTModel machine-translation models. If lightseq could support this model, I believe it would help many more researchers and developers. I hope lightseq will support it in the future.