triton-inference-server / fastertransformer_backend

BSD 3-Clause "New" or "Revised" License

Support for Flan-T5 with gated activation and non-shared embeddings #82

Closed: LydiaXiaohongLi closed this issue 1 year ago

LydiaXiaohongLi commented 1 year ago

Hi Team.

I am following https://github.com/triton-inference-server/fastertransformer_backend/blob/dev/t5_gptj_blog/notebooks/GPT-J_and_T5_inference.ipynb for Triton inference with T5. I was able to replicate it for t5-3b: the generated output matches that of the original t5-3b model, and inference is faster. However, if I change the model to flan-t5-xl, the generated output is way off from that of the original flan-t5-xl. Note: I modified FasterTransformer/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py so that the additional wi_0 and wi_1 parameters of the gated activation and the non-shared lm_head embeddings could be downloaded and converted.
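For context, the structural difference in the feed-forward block can be sketched in NumPy. This is a minimal illustration that assumes the Hugging Face definitions: original T5 uses a single ReLU input projection (`wi`), while Flan-T5 uses the "gated-gelu" variant, where a GELU branch (`wi_0`) is multiplied element-wise by a linear branch (`wi_1`). The function names are hypothetical, and the matrices are in (in, out) layout, i.e. transposed relative to PyTorch `nn.Linear.weight`. A converter that only knows about `wi` has no place for `wi_0` and `wi_1`.

```python
import numpy as np

def gelu_new(x):
    # Tanh approximation of GELU ("gelu_new" in Hugging Face)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def relu_ffn(x, wi, wo):
    # Original T5 feed-forward: one input projection followed by ReLU
    return np.maximum(x @ wi, 0.0) @ wo

def gated_gelu_ffn(x, wi_0, wi_1, wo):
    # Flan-T5 feed-forward: GELU(x @ wi_0) gates the linear branch x @ wi_1
    return (gelu_new(x @ wi_0) * (x @ wi_1)) @ wo

# Toy sizes for illustration only
rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
x = rng.standard_normal((2, d_model))
wi = rng.standard_normal((d_model, d_ff))
wi_0 = rng.standard_normal((d_model, d_ff))
wi_1 = rng.standard_normal((d_model, d_ff))
wo = rng.standard_normal((d_ff, d_model))

y_relu = relu_ffn(x, wi, wo)
y_gated = gated_gelu_ffn(x, wi_0, wi_1, wo)
print(y_relu.shape, y_gated.shape)  # (2, 8) (2, 8)
```

Both variants share `wo`; only the input side differs, which is why the converter has to handle two input matrices per layer instead of one.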

I guess Flan-T5 is not supported by simply following https://github.com/triton-inference-server/fastertransformer_backend/blob/dev/t5_gptj_blog/notebooks/GPT-J_and_T5_inference.ipynb, because of this slight change in model structure? If so, how could I make it work for Flan-T5 models?
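The non-shared embedding also changes the output projection. A minimal sketch, assuming the behavior of Hugging Face's `T5ForConditionalGeneration` as I understand it: with tied embeddings (original T5) the decoder output is scaled by `d_model**-0.5` and projected with the shared input embedding; with untied embeddings (`tie_word_embeddings=False`, as in Flan-T5) a separate `lm_head` matrix is used and no scaling is applied. `t5_logits` is a hypothetical helper, not part of FasterTransformer.

```python
import numpy as np

def t5_logits(decoder_out, shared, lm_head=None):
    """decoder_out: (seq, d_model); shared: (vocab, d_model) input embedding;
    lm_head: (vocab, d_model), or None when embeddings are tied."""
    d_model = decoder_out.shape[-1]
    if lm_head is None:
        # Tied (original T5): reuse the input embedding, rescale by d_model**-0.5
        return (decoder_out * d_model**-0.5) @ shared.T
    # Untied (Flan-T5): separate lm_head matrix, no rescaling
    return decoder_out @ lm_head.T

# Toy sizes for illustration only
rng = np.random.default_rng(0)
d_model, vocab = 16, 10
h = rng.standard_normal((3, d_model))
shared = rng.standard_normal((vocab, d_model))
lm_head = rng.standard_normal((vocab, d_model))

logits_tied = t5_logits(h, shared)
logits_untied = t5_logits(h, shared, lm_head)
print(logits_tied.shape, logits_untied.shape)  # (3, 10) (3, 10)
```

Using the tied path for an untied checkpoint would both pick the wrong matrix and apply a scaling factor the model was not trained with.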

Thank you!

Appendix: I added the blocks below to the split_and_convert_process function in FasterTransformer/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py:

elif key.find("lm_head.weight") != -1:
    # Flan-T5's lm_head is not tied to the shared embedding; it is not split,
    # so only rank 0's copy is saved, plus a transposed copy
    saved_path = saved_dir / f"{saved_key}.bin"
    val.tofile(saved_path.as_posix())

    saved_path = saved_dir / f"{saved_key}_T.bin"
    val.T.tofile(saved_path.as_posix())
elif (
    key.find("DenseReluDense.wi_0.weight") != -1
    or key.find("DenseReluDense.wi_1.weight") != -1
    or (key.find("encoder") != -1 and (
        key.find("SelfAttention.q.weight") != -1
        or key.find("SelfAttention.k.weight") != -1
        or key.find("SelfAttention.v.weight") != -1
    ))
    or key.find("EncDecAttention.q.weight") != -1
    or key.find("EncDecAttention.k.weight") != -1
    or key.find("EncDecAttention.v.weight") != -1
):
    # Split along the last axis and write one file per tensor-parallel rank
    split_vals = np.split(val, factor, axis=-1)
    for j in range(factor):
        saved_path = saved_dir / f"{saved_key}.{j:d}.bin"
        split_vals[j].tofile(saved_path.as_posix())
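The tensor-parallel branch above writes each slice with `ndarray.tofile`, which stores raw bytes with no shape or dtype header, so a reader must already know both. A minimal round-trip check, assuming a 2-D weight whose last dimension is divisible by `factor`; `split_and_save`, `load_and_merge`, and the key `demo.wi_0.weight` are hypothetical names for illustration.

```python
import tempfile
from pathlib import Path

import numpy as np

def split_and_save(val, factor, saved_dir, saved_key):
    # Same pattern as the converter branch: one raw .bin file per rank
    for j, part in enumerate(np.split(val, factor, axis=-1)):
        part.tofile((saved_dir / f"{saved_key}.{j:d}.bin").as_posix())

def load_and_merge(factor, saved_dir, saved_key, shape, dtype):
    # tofile() keeps no metadata, so shape and dtype must be supplied here
    rows, cols = shape
    parts = [
        np.fromfile((saved_dir / f"{saved_key}.{j:d}.bin").as_posix(), dtype=dtype)
        .reshape(rows, cols // factor)
        for j in range(factor)
    ]
    return np.concatenate(parts, axis=-1)

with tempfile.TemporaryDirectory() as tmp:
    saved_dir = Path(tmp)
    val = np.arange(24, dtype=np.float32).reshape(4, 6)
    split_and_save(val, 2, saved_dir, "demo.wi_0.weight")
    restored = load_and_merge(2, saved_dir, "demo.wi_0.weight", val.shape, np.float32)
    print(np.array_equal(val, restored))  # True
```

`tofile` always writes in C order, so the non-contiguous views returned by `np.split` still round-trip correctly.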
byshiue commented 1 year ago

Please try the following scripts on the latest main branch. You don't need to modify the converter.

sudo apt-get install git-lfs
git lfs install
git lfs clone https://huggingface.co/google/flan-t5-small

python3 ./build/_deps/repo-ft-src/examples/pytorch/t5/utils/huggingface_t5_ckpt_convert.py \
        -saved_dir flan-t5-small/c-models \
        -in_file flan-t5-small/ \
        -inference_tensor_para_size 1 \
        -weight_data_type fp32
tritonserver --model-repository=all_models/t5/ &
python3 tools/t5_utils/summarization.py --ft_model_location flan-t5-small/c-models/1-gpu/ \
                                        --hf_model_location flan-t5-small/ \
                                        --test_ft \
                                        --test_hf \
                                        --data_type fp16
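The commands above convert the weights with `-weight_data_type fp32` and then run `summarization.py` with `--data_type fp16`. Assuming the fp16 path runs the model in half precision, the sketch below shows what the fp16 format does to values: very small ones underflow to zero, anything above 65504 becomes inf, and the rest keep roughly three significant decimal digits. The sample values are arbitrary; this only illustrates the number format and is not a diagnosis of any output difference.

```python
import numpy as np

# Arbitrary example values as they might appear in an fp32 checkpoint
weights_fp32 = np.array([0.1234567, 1e-8, 70000.0], dtype=np.float32)

# Cast to half precision; silence the overflow warning for the last value
with np.errstate(over="ignore"):
    weights_fp16 = weights_fp32.astype(np.float16)

fp16 = np.finfo(np.float16)
print("fp16 max:", float(fp16.max))                   # fp16 max: 65504.0
print("rounding error:", abs(float(weights_fp16[0]) - 0.1234567))
print("underflow to zero:", weights_fp16[1] == 0)     # underflow to zero: True
print("overflow to inf:", np.isinf(weights_fp16[2]))  # overflow to inf: True
```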
LydiaXiaohongLi commented 1 year ago

Thank you!

lakshaykc commented 1 year ago

@byshiue I ran tests according to your instructions above and they work. However, the quality of the FT outputs degrades relative to the HF outputs as the model size increases, especially for flan-t5-xl and flan-t5-xxl. In fact, the flan-t5-xxl output is meaningless. I've created a new issue (#95) for this.
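To put a number on how far FT outputs drift from HF outputs across model sizes, one option is a simple overlap score between the two generations for the same inputs. A minimal sketch: `unigram_f1` and `agreement` are hypothetical helpers implementing a simplified ROUGE-1-style F1 on whitespace tokens (no stemming or stopword handling), not necessarily the metric that `summarization.py` reports.

```python
from collections import Counter

def unigram_f1(candidate: str, reference: str) -> float:
    # ROUGE-1-style F1 on lowercased whitespace tokens
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # clipped token matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

def agreement(ft_outputs, hf_outputs):
    # Mean F1 of FT outputs scored against HF outputs for the same inputs
    if not ft_outputs or len(ft_outputs) != len(hf_outputs):
        raise ValueError("need two non-empty lists of equal length")
    scores = [unigram_f1(ft, hf) for ft, hf in zip(ft_outputs, hf_outputs)]
    return sum(scores) / len(scores)

hf = ["the cat sat on the mat", "flan t5 is a model"]
ft = ["the cat sat on a mat", "flan t5 is a model"]
print(round(agreement(ft, hf), 3))  # 0.917
```

Tracking this score for flan-t5-small through flan-t5-xxl would make a size-dependent degradation visible as a trend rather than a visual impression.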