Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with ⚡ PyTorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

LightningCLI compatibility and Type["AutoModel"] #274

Closed: mauvilsa closed this 2 years ago

mauvilsa commented 2 years ago

In a jsonargparse issue, someone commented that to make lightning-transformers work, they had to change the type Type["AutoModel"] to just Type (see jsonargparse#146 comment). Forward references like "AutoModel" are not currently supported in jsonargparse. Support could be added, but from the lightning-transformers source this forward reference looks like a mistake. A type hint like Type[<class>] means the accepted values are <class> itself or any subclass of it. I could be wrong, but it seems to me that the models don't inherit from AutoModel. Is this just a bug, and should the type be just Type?
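To make the forward-reference problem concrete: with plain Python `typing`, `Type["AutoModel"]` stores an unresolved `ForwardRef`, and a tool has to evaluate it against the module's globals (as `typing.get_type_hints` does) before it can do anything with the class. A minimal sketch; `AutoModel` and `build` here are hypothetical stand-ins, not the real transformers or lightning-transformers objects:

```python
from typing import ForwardRef, Type, get_args, get_type_hints


class AutoModel:  # hypothetical stand-in, not transformers.AutoModel
    pass


def build(downstream_model_type: Type["AutoModel"]) -> None:  # hypothetical function
    pass


# The raw annotation still holds the string wrapped in a ForwardRef.
raw = build.__annotations__["downstream_model_type"]
print(get_args(raw))  # a 1-tuple containing ForwardRef('AutoModel')

# get_type_hints evaluates the ForwardRef in build's module globals.
resolved = get_type_hints(build)["downstream_model_type"]
print(get_args(resolved)[0] is AutoModel)  # True
```

A parser that only inspects the raw annotation sees a string, not a class, which is why such a hint needs explicit forward-reference support.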

https://github.com/Lightning-AI/lightning-transformers/blob/bf8a215af13d07c31b052788f7b694412e6a32f7/lightning_transformers/core/model.py#L50

Borda commented 2 years ago

@carmocca, mind having a look? :rabbit:

carmocca commented 2 years ago

but it seemed to me that the models don't inherit from AutoModel

IIRC AutoModel is the API that defines .from_pretrained(): https://github.com/Lightning-AI/lightning-transformers/blob/master/lightning_transformers/core/model.py#L77. But I'm not an HF expert.

mauvilsa commented 2 years ago

AutoModel inherits from a class _BaseAutoModelClass (see modeling_auto.py#L838) and that one does define .from_pretrained(), but my question remains. I debugged a couple of lightning-transformers unit tests to see which kind of classes were given to the downstream_model_type parameter. One example is transformers.models.auto.modeling_auto.AutoModelForCausalLM. This does not inherit from AutoModel, but it does inherit from _BaseAutoModelClass. See:

>>> import transformers
>>> issubclass(transformers.models.auto.modeling_auto.AutoModelForCausalLM, transformers.AutoModel)
False
>>> issubclass(transformers.models.auto.modeling_auto.AutoModelForCausalLM, transformers.models.auto.modeling_auto._BaseAutoModelClass)
True

The type hint would need to be Type[_BaseAutoModelClass] instead of Type[AutoModel].

A related question: why does it have to be a forward reference?

carmocca commented 2 years ago

The type hint would need to be Type[_BaseAutoModelClass] instead of Type[AutoModel].

Fine, this can be changed.

why does it have to be a forward reference?

It doesn't have to. It's like this because it's conditionally imported under TYPE_CHECKING, probably to avoid unnecessarily importing transformers for just an annotation. Could be changed.
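The TYPE_CHECKING pattern explains the failure: the guarded import runs only for static type checkers, so at runtime the name never exists and the string annotation cannot be resolved. A minimal sketch, using the stdlib `decimal` module as a hypothetical stand-in for `transformers` (the function names are also hypothetical):

```python
import decimal
from typing import TYPE_CHECKING, Type, get_args, get_type_hints

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from decimal import Decimal


def lazy(model_type: Type["Decimal"]) -> None:  # forward reference, as in the issue
    pass


def eager(model_type: Type[decimal.Decimal]) -> None:  # runtime import, no forward ref
    pass


try:
    get_type_hints(lazy)
    resolved = True
except NameError:
    # "Decimal" was never bound at runtime, so the ForwardRef cannot be evaluated.
    resolved = False

eager_hint = get_type_hints(eager)["model_type"]
print(resolved)                                  # False
print(get_args(eager_hint)[0] is decimal.Decimal)  # True
```

The trade-off is the one mentioned above: importing the dependency at module level makes the annotation resolvable by runtime tools such as jsonargparse, at the cost of importing it even when it is only needed for the annotation.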