guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

Potentially use BetterTransformer from PyTorch #302

Open lerouxrgd opened 1 year ago

lerouxrgd commented 1 year ago

Hello,

As described on PyTorch's blog, since version 1.12 it is possible to have significantly faster transformers.

To benefit from it in Python, one has to use pre-built modules such as TransformerEncoder. Looking at the source code, it seems to boil down to calling _transformer_encoder_layer_fwd, which is also available in tch-rs.
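
For reference, here is a rough sketch of what invoking the fused kernel through tch-rs might look like. The binding name internal_transformer_encoder_layer_fwd is an assumption based on how tch-rs usually exposes underscore-prefixed ATen ops, and the argument order follows the PyTorch native signature of _transformer_encoder_layer_fwd; both would need to be checked against the generated tch-rs bindings.

```rust
// Sketch only: the binding name and the argument order are assumptions.
// tch-rs usually maps an ATen op `_foo` to `Tensor::internal_foo`, and the
// arguments below follow the PyTorch native signature of
// `_transformer_encoder_layer_fwd`; both should be verified against the
// generated tch-rs bindings.
use tch::Tensor;

#[allow(clippy::too_many_arguments)]
fn fused_encoder_layer(
    src: &Tensor,           // (batch, seq_len, embed_dim)
    embed_dim: i64,
    num_heads: i64,
    qkv_weight: &Tensor,    // packed query/key/value projection weights
    qkv_bias: &Tensor,
    proj_weight: &Tensor,   // attention output projection
    proj_bias: &Tensor,
    norm_weight_1: &Tensor, // first layer norm
    norm_bias_1: &Tensor,
    norm_weight_2: &Tensor, // second layer norm
    norm_bias_2: &Tensor,
    ffn_weight_1: &Tensor,  // feed-forward intermediate layer
    ffn_bias_1: &Tensor,
    ffn_weight_2: &Tensor,  // feed-forward output layer
    ffn_bias_2: &Tensor,
    mask: Option<&Tensor>,  // optional attention mask
) -> Tensor {
    src.internal_transformer_encoder_layer_fwd(
        embed_dim,
        num_heads,
        qkv_weight,
        qkv_bias,
        proj_weight,
        proj_bias,
        true,  // use_gelu (BERT-style activation)
        false, // norm_first: false for post-layer-norm models such as BERT
        1e-12, // layer norm epsilon
        norm_weight_1,
        norm_bias_1,
        norm_weight_2,
        norm_bias_2,
        ffn_weight_1,
        ffn_bias_1,
        ffn_weight_2,
        ffn_bias_2,
        mask,
        None, // mask_type
    )
}
```

A call like this would replace the hand-written self-attention and feed-forward blocks inside an encoder layer's forward pass.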

Do you think it would be possible to make use of it in rust-bert? I can have a look at it if you think it is worth investigating.

guillaume-be commented 1 year ago

Hello @lerouxrgd ,

Yes, the availability of BetterTransformer is an interesting development. The challenge for an integration into the library is twofold:

  1. many of the language models in the library implement the attention mechanism from scratch, often with subtle variations that may not match the BetterTransformer module.
  2. even if the logic of the transformer block were identical between the base implementation and BetterTransformer, the submodules and parameters may have different names and would not be loaded correctly by torch.load_state_dict in Python (or VarStore::load in the Rust version). The weights may have to be re-exported with updated variable names (see the sketch after this list), causing a loss of backward compatibility if the old files are removed.
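
To illustrate the second point, below is a rough sketch of re-exporting a weights file under new variable names using tch-rs. The rename_variable mapping shown is purely hypothetical; the real correspondence between the current parameter names and those expected by a BetterTransformer-style block would have to be established per model, and some tensors (e.g. a packed QKV projection) would need to be concatenated rather than simply renamed.

```rust
// Sketch only: `rename_variable` and the name pair it rewrites are
// hypothetical examples, not the actual mapping a BetterTransformer-based
// block would require.
use tch::Tensor;

fn rename_variable(name: &str) -> String {
    // Hypothetical example: map the current BERT feed-forward naming onto the
    // `linear1` naming used by PyTorch's TransformerEncoderLayer.
    name.replace("intermediate.dense", "linear1")
}

fn reexport_weights(input: &str, output: &str) -> Result<(), tch::TchError> {
    // Load all named tensors from the existing weights file.
    let named_tensors = Tensor::load_multi(input)?;
    // Rename each variable so it can be resolved by the new module structure.
    let renamed: Vec<(String, Tensor)> = named_tensors
        .into_iter()
        .map(|(name, tensor)| (rename_variable(&name), tensor))
        .collect();
    let named_refs: Vec<(&str, Tensor)> = renamed
        .iter()
        .map(|(name, tensor)| (name.as_str(), tensor.shallow_clone()))
        .collect();
    // Save under the new names; the original file can be kept around for
    // backward compatibility.
    Tensor::save_multi(&named_refs, output)?;
    Ok(())
}
```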

It may be worth keeping an eye on the related issues in the Python library (e.g. https://github.com/huggingface/transformers/issues/20372, https://github.com/huggingface/transformers/pull/19966, https://github.com/huggingface/transformers/pull/19632) and the documentation page at https://huggingface.co/docs/optimum/bettertransformer/tutorials/contribute.