huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Advice on creating/wrapping `PreTrainedModel` to be compatible with the codebase? #10571

Closed HanGuo97 closed 3 years ago

HanGuo97 commented 3 years ago

Who can help

@patrickvonplaten

Thanks for the amazing library!

I'm curious whether there are instructions on creating a `PreTrainedModel` subclass, or on creating an `nn.Module` that behaves like a `PreTrainedModel`. Suppose I want to wrap an existing model with some simple additional capabilities inside an `nn.Module` -- which methods do I need to implement or override so that the wrapper works well with the existing examples?

I'm aware of the tutorials on creating a new model from scratch, but that process seems fairly complicated and involved, whereas I'm only interested in adding a couple of simple features.

For example, in the Seq2Seq example, I have noticed that the signature of `model.forward` determines which data will (or will not) be passed to the model (as in `trainer._remove_unused_columns`), and that the existence of `model.prepare_decoder_input_ids_from_labels` also influences the input data (as in `DataCollatorForSeq2Seq.__call__`).
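To make the concern concrete, here is a minimal sketch of why a naive wrapper can break those two mechanisms. The `BaseSeq2SeqModel`, `WrappedModel`, and `columns_kept` names are illustrative stand-ins, not the real `transformers` API: `columns_kept` mimics the signature-inspection idea behind `trainer._remove_unused_columns`, and `__getattr__` delegation keeps `hasattr`-style checks (as in the data collator) working on the wrapper.

```python
import inspect

class BaseSeq2SeqModel:
    """Toy stand-in for a pretrained seq2seq model (not the real API)."""

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return {"loss": 0.0}

    def prepare_decoder_input_ids_from_labels(self, labels):
        # Toy "shift right": prepend a start token id of 0.
        return [0] + labels[:-1]

class WrappedModel:
    """Wrapper that stays compatible with signature-based column filtering."""

    def __init__(self, model):
        self.model = model

    # Mirror the wrapped model's forward signature explicitly, so that
    # inspect.signature-based filtering still sees the argument names.
    # A generic `def forward(self, *args, **kwargs)` would hide them,
    # and every dataset column would then be dropped.
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return self.model.forward(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

    def __getattr__(self, name):
        # Delegate everything else (e.g. prepare_decoder_input_ids_from_labels)
        # to the wrapped model, so hasattr checks on the wrapper keep working.
        return getattr(self.model, name)

def columns_kept(model, dataset_columns):
    """Toy version of the Trainer's column filtering: keep only columns
    whose names appear in the model's forward signature."""
    sig_params = set(inspect.signature(model.forward).parameters)
    return [c for c in dataset_columns if c in sig_params]

wrapped = WrappedModel(BaseSeq2SeqModel())
print(columns_kept(wrapped, ["input_ids", "labels", "extra_column"]))
# → ['input_ids', 'labels']
print(hasattr(wrapped, "prepare_decoder_input_ids_from_labels"))
# → True
```

The two design points this sketch tries to capture: the wrapper's `forward` must expose the argument names the trainer should keep, and attributes the collator probes for must remain reachable through the wrapper.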

It'd be great if someone could point me to some guidance on tweaking the model to be compatible with the rest of the codebase. Thanks in advance for your time!


patrickvonplaten commented 3 years ago

Hey @HanGuo97,

We try to keep the GitHub issues for bug reports. Do you mind asking your question on the forum instead? Also there might already be similar questions on the forum, such as https://discuss.huggingface.co/t/create-a-custom-model-that-works-with-any-pretrained-transformer-body/4186. Thanks!

HanGuo97 commented 3 years ago

Got it, thanks for letting me know!