bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation

support for MBART (big models)? #114

Open leoozy opened 2 years ago

leoozy commented 2 years ago

Hello, thank you for your contribution. However, I notice that all MBART models exceed 2 GB. Do you have any plan to fix this issue?

leoozy commented 2 years ago

Currently, it looks like the tool does not support models exceeding 2 GB.

Taka152 commented 2 years ago

You can check here: #63
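For anyone landing here: #63 discusses moving large checkpoints to HDF5, since a single serialized protobuf message cannot exceed 2 GB, while HDF5 has no such cap. Below is a minimal sketch of that idea, assuming you already have the weights as numpy arrays; the dataset names are illustrative, not LightSeq's actual expected layout (that is defined by its export scripts):

```python
# Sketch: serialize model weights to HDF5 instead of protobuf to avoid
# protobuf's 2 GB per-message limit. Dataset names here are hypothetical.
import h5py
import numpy as np

def save_weights_hdf5(path, named_weights):
    """named_weights: dict of {name: numpy array}, e.g. extracted from a HF model."""
    with h5py.File(path, "w") as f:
        for name, arr in named_weights.items():
            f.create_dataset(name, data=np.asarray(arr, dtype=np.float32))

# usage (hypothetical):
# save_weights_hdf5("lightseq_mbart.hdf5", {"src_emb_wei": emb, ...})
```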

leoozy commented 2 years ago

@Taka152 Hello, thank you for your reply. I am trying to accelerate the MBART model, but the model size is too large. Could the main branch solve this issue? I noticed some comments about large models in the main branch.

I changed the number of encoder/decoder layers of the MBART model to 2 in the config.json file, but the error still exists (byte size exceeds 2 GB). That should be impossible for a 2-encoder/2-decoder MBART model. Do you know why? Thank you!
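For context, a back-of-the-envelope sketch (my own arithmetic, not from the LightSeq code) of why shrinking the layer count barely helps: with MBART's 250k-token vocabulary, the two embedding tables alone are close to protobuf's 2 GB cap, which also matches the ~976 MB per embedding table reported in the log in the next comment.

```python
# Rough size estimate for the MBART embedding tables (fp32).
vocab_size = 250_031      # src/trg vocab size from the model config
hidden_size = 1024
bytes_per_param = 4       # float32

one_table_mb = vocab_size * hidden_size * bytes_per_param / 2**20
both_tables_gb = 2 * vocab_size * hidden_size * bytes_per_param / 2**30
print(f"one embedding table: {one_table_mb:.0f} MB")     # ~977 MB
print(f"src + trg embeddings: {both_tables_gb:.2f} GB")  # ~1.91 GB
```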

leoozy commented 2 years ago

```
initializing bart tokenizer...
creating lightseq model...
Parsing hdf5: /home/sysadmin/downlaod/lightseq_models/lightseq_mbart_base.hdf5
loading 976 MB of embedding weight.
Finish loading src_emb_wei from host to device
loading 1073 MB of embedding weight.
Finish loading trg_emb_wei from host to device
loading 576 MB of encoder weight.
Finish loading enc_wei from host to device
loading 672 MB of decoder weight.
Finish loading dec_wei from host to device
Finish loading all weight from host to device
model config
encoder layers: 12
decoder layers: 12
hidden size: 1024
inner size: 4096
head number: 12
dim per head: 85
src vocab size: 250031
trg vocab size: 250031
is_post_ln: 1
no_scale_embedding: 1
use_gelu: 1
start_id: 2
end_id: 2
padding_id: 1
is_multilingual: 0

generator config
beam size: 4
extra decode length(max decode length - src input length): 50
length penalty: 1
diverse lambda: 0
sampling method: beam_search
topk: 1
topp: 0.75

Traceback (most recent call last):
  File "ls_bart.py", line 102, in <module>
    main()
  File "ls_bart.py", line 69, in main
    ls_model = lsi.Transformer("/home/sysadmin/downlaod/lightseq_models/lightseq_mbart_base.hdf5", 128)
RuntimeError: violate dim_per_head % 2 = 0
```

Thank you for your new version. I am trying to accelerate the Hugging Face MBART model and successfully generated the h5 file. But when I run `python ls_bart.py`, I get the error above. Could you please tell me how to solve it?
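Not a fix from the maintainers, but a hypothetical sanity check: the log above reports head number 12 with dim per head 85, yet 1024 is not evenly divisible by 12, and mbart-large checkpoints actually use 16 attention heads (1024 / 16 = 64, which is even). It may be worth verifying the head count that ends up in the exported file; the checkpoint name below is an assumption, substitute whichever one you exported.

```python
from transformers import MBartConfig

# Assumed checkpoint; swap in the one you actually exported.
config = MBartConfig.from_pretrained("facebook/mbart-large-cc25")

heads = config.encoder_attention_heads       # 16 for mbart-large-*
dim_per_head = config.d_model // heads       # 1024 // 16 = 64

print(f"heads={heads}, dim_per_head={dim_per_head}")
assert config.d_model % heads == 0, "hidden size must divide evenly across heads"
assert dim_per_head % 2 == 0, "LightSeq requires an even dim_per_head"
```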

lilyzlt commented 2 years ago


I have the same issue.