syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

How to load my own model #12

Closed zhihui-shao closed 1 year ago

zhihui-shao commented 1 year ago

I trained a model using train.py and got a checkpoint folder. How do I load this model for inference?

syncdoth commented 1 year ago

Since the model supports almost all of the Hugging Face transformers API, you can load your model with .from_pretrained(CKPT_DIR), where CKPT_DIR is the checkpoint folder written during training.

Example:

from retnet.modeling_retnet import RetNetModelWithLMHead
from transformers import AutoTokenizer

model = RetNetModelWithLMHead.from_pretrained('checkpoints/checkpoint-xxx')
# Trainer API should save tokenizers too in the same directory
tokenizer = AutoTokenizer.from_pretrained('checkpoints/checkpoint-xxx')
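
If you are not sure which checkpoint folder to pass, the following is a minimal sketch for picking the most recent one. It assumes the training run used the Hugging Face Trainer, which names its checkpoint folders checkpoint-<step>. The helper name latest_checkpoint and the 'checkpoints' output directory are illustrative, not part of this repo.

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: folders follow the Trainer convention "checkpoint-<global_step>".
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint-<step> subfolder with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None

    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _CKPT_RE.match(child.name)
        # Skip files and unrelated folders (logs, runs, ...).
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path


if __name__ == "__main__":
    # Hypothetical output directory; replace with your own.
    print(latest_checkpoint("checkpoints"))
```

The returned path can be passed to from_pretrained for both the model and the tokenizer, in place of the literal checkpoint directory. transformers also provides a similar helper, get_last_checkpoint, in transformers.trainer_utils.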