bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation

export and infer fairseq transformer model #438

Open koukan3 opened 1 year ago

koukan3 commented 1 year ago

Hi, I trained a translation model with the fairseq tool. The model structure is as follows:

encoder.embed_tokens.weight
encoder.embed_positions._float_tensor
encoder.layernorm_embedding.weight
encoder.layernorm_embedding.bias
encoder.layers.0.self_attn.k_proj.weight
encoder.layers.0.self_attn.k_proj.bias
......
encoder.layers.11.final_layer_norm.weight
encoder.layers.11.final_layer_norm.bias
encoder.layer_norm.weight
encoder.layer_norm.bias
decoder.embed_tokens.weight
decoder.embed_positions._float_tensor
decoder.layernorm_embedding.weight
decoder.layernorm_embedding.bias
decoder.layers.0.self_attn.k_proj.weight
decoder.layers.0.self_attn.k_proj.bias
......
decoder.layer_norm.weight
decoder.layer_norm.bias
decoder.output_projection.weight
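(A listing like this can be reproduced by loading the checkpoint directly; a minimal sketch, assuming the checkpoint file is named checkpoint_best.pt:)

```python
import torch

# Minimal sketch: print the parameter names stored in a fairseq checkpoint.
# The file name "checkpoint_best.pt" is an assumption.
state_dict = torch.load("checkpoint_best.pt", map_location="cpu")["model"]
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```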

I can't find the field corresponding to layernorm_embedding in the transformer.proto file, and I think the "norm_scale" field holds the layer_norm weight. How should I assign the layernorm_embedding values (e.g. encoder.layernorm_embedding.weight) to the Transformer object?

Leaving that problem aside, I ran native_fs_transformer_export.py and a protobuf format model was exported successfully. However, the inference result is not as expected:

src = [[87189, 80536, 27780, 2019, 89359, 87300, 90001]]
model.infer(src)
(array([[[92211, 92211, 92211, 92211, 92211, 2]]], dtype=int32), array([[0.]], dtype=float32))

The result is unusual, with the same id repeated. I think there is some problem with the export. Can you help me figure out what is happening? Thank you.
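(For context, the call above presumably went through the LightSeq Python inference API; a minimal sketch, assuming the exported file is named transformer.pb and a maximum batch size of 8:)

```python
import lightseq.inference as lsi

# Minimal sketch of the export-then-infer flow described above.
# The file name "transformer.pb" and the max batch size of 8 are assumptions.
model = lsi.Transformer("transformer.pb", 8)
src = [[87189, 80536, 27780, 2019, 89359, 87300, 90001]]
output_ids, scores = model.infer(src)
print(output_ids, scores)
```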

Taka152 commented 1 year ago

If your model has layernorm_embedding and layernorm_before=True, it isn't supported by LightSeq for now. LightSeq only supports layernorm_embedding with layernorm_before=False; you can check the huggingface_bart example for how to export that configuration.
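(If you are unsure which configuration a checkpoint was trained with, one way to check is to read the training args saved inside it; a minimal sketch, assuming the file is checkpoint_best.pt and noting that newer fairseq versions store the config under "cfg" rather than "args":)

```python
import torch

# Minimal sketch: inspect the fairseq flags that decide the layer-norm layout.
# layernorm_embedding is set by --layernorm-embedding; "layernorm_before"
# corresponds to --encoder-normalize-before / --decoder-normalize-before.
ckpt = torch.load("checkpoint_best.pt", map_location="cpu")
args = ckpt.get("args") or ckpt["cfg"]["model"]  # older vs. newer fairseq layout
print("layernorm_embedding:", getattr(args, "layernorm_embedding", False))
print("encoder_normalize_before:", getattr(args, "encoder_normalize_before", False))
print("decoder_normalize_before:", getattr(args, "decoder_normalize_before", False))
```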

frankang commented 1 year ago

Does lightseq support layernorm_embedding=False and layernorm_before=False? The conversion script native_fs_transformer_export.py defines a mandatory mapping dict for source and target embedding layer norm weight (src_emb_mapping_dict and trg_emb_mapping_dict). @Taka152

Taka152 commented 1 year ago

> Does lightseq support layernorm_embedding=False and layernorm_before=False? The conversion script native_fs_transformer_export.py defines a mandatory mapping dict for source and target embedding layer norm weight (src_emb_mapping_dict and trg_emb_mapping_dict). @Taka152

Nope, currently we only support BERT and the vanilla transformer, which correspond to layernorm_embedding=True && layernorm_before=False and layernorm_embedding=False && layernorm_before=True, respectively.

frankang commented 1 year ago

Thanks for the clarification. Since the conversion script native_fs_transformer_export.py requires the embedding layer-norm weights to exist (see the code references below), how can I convert a native fairseq model trained with layernorm_embedding=False && layernorm_before=True? Can I just delete those lines, or is there a variable to set, like the is_post_ln attribute in the export_ls_config function? (See also the sketch after the code references below.)

https://github.com/bytedance/lightseq/blob/f8ac0cba3c5049b1a768da834254b50023e4cf37/examples/inference/python/export/fairseq/native_fs_transformer_export.py#L59-L70

https://github.com/bytedance/lightseq/blob/f8ac0cba3c5049b1a768da834254b50023e4cf37/examples/inference/python/export/fairseq/native_fs_transformer_export.py#L157-L162
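(Not an official answer, but a rough sketch of one possible workaround: instead of deleting those mapping lines, the embedding layer-norm slots could be filled with an identity layer norm, i.e. scale of ones and bias of zeros, when the checkpoint has no layernorm_embedding weights. The proto field names, the hidden size, and whether this yields correct outputs are all assumptions; `transformer` stands for the proto object built in native_fs_transformer_export.py.)

```python
import numpy as np

# Untested sketch: fill the embedding layer-norm slots with an identity layer
# norm when the fairseq checkpoint was trained with layernorm_embedding=False.
# Field names (src_embedding/trg_embedding, norm_scale, norm_bias) and
# hidden_size are assumptions about the export script's proto layout.
hidden_size = 512
identity_scale = np.ones(hidden_size, dtype=np.float32).tolist()
identity_bias = np.zeros(hidden_size, dtype=np.float32).tolist()
transformer.src_embedding.norm_scale.extend(identity_scale)
transformer.src_embedding.norm_bias.extend(identity_bias)
transformer.trg_embedding.norm_scale.extend(identity_scale)
transformer.trg_embedding.norm_bias.extend(identity_bias)
```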

baoguo1995 commented 1 year ago

> Does lightseq support layernorm_embedding=False and layernorm_before=False? The conversion script native_fs_transformer_export.py defines a mandatory mapping dict for source and target embedding layer norm weight (src_emb_mapping_dict and trg_emb_mapping_dict). @Taka152
>
> Nope, currently we only support BERT and the vanilla transformer, which correspond to layernorm_embedding=True && layernorm_before=False and layernorm_embedding=False && layernorm_before=True, respectively.

My transformer model is layernorm_embedding=False && layernorm_before=True, but I have also encountered the same problem. [image attached in the original issue]