Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with ⚡ PyTorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0

Ensure non-Hydra compatibility #113

Closed · SeanNaren closed this 3 years ago

SeanNaren commented 3 years ago

🚀 Feature

I should be able to do this for all tasks:

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelPruning
from transformers import AutoTokenizer

from lightning_transformers.core.nlp.huggingface.config import HFBackboneConfig
from lightning_transformers.task.nlp.translation import WMT16TranslationDataModule, TranslationTransformer
from lightning_transformers.task.nlp.translation.config import TranslationDataConfig

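# Subclass the task to override the optimizer in plain PyTorch, with no Hydra config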
class MyTranslationTransformer(TranslationTransformer):
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='t5-base')

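# Construct the task directly, passing the downstream model type and backbone config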
model = MyTranslationTransformer(
    'transformers.AutoModelForSeq2SeqLM',
    HFBackboneConfig(pretrained_model_name_or_path='t5-base')
)

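# Set up the WMT16 datamodule from a plain dataclass config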
dm = WMT16TranslationDataModule(
    cfg=TranslationDataConfig(source_language='en', target_language='ro'),
    tokenizer=tokenizer
)

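# Train as usual with a vanilla Lightning Trainer and callbacks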
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=5,
    callbacks=[ModelPruning('l1_unstructured', use_lottery_ticket_hypothesis=True)]
)

trainer.fit(model, dm)

Currently this crashes with:

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    model = TranslationTransformer(
  File "/home/sean/lightning-transformers/lightning_transformers/task/nlp/translation/model.py", line 12, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required keyword-only argument: 'cfg'
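One possible direction, as a minimal sketch only (the class and config names below are illustrative assumptions, not the library's actual internals): give cfg a default in the task constructor so the class can be built directly, without Hydra supplying the config.

import pytorch_lightning as pl
from dataclasses import dataclass

@dataclass
class TaskConfig:
    # Stand-in for the task options that Hydra would normally populate.
    pass

class TaskTransformer(pl.LightningModule):
    def __init__(self, downstream_model_type: str, backbone, *, cfg: TaskConfig = None):
        super().__init__()
        # Fall back to a default config when none is passed, instead of
        # raising TypeError for a missing keyword-only argument.
        self.cfg = cfg if cfg is not None else TaskConfig()

With a default in place, the instantiation in the snippet above would no longer raise the TypeError, while Hydra-driven runs could keep passing an explicit cfg.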