Closed: SeanNaren closed this issue 3 years ago.
I should be able to do the following for every task: subclass the task's transformer, instantiate it directly, and train it:
# Reproduction: subclass TranslationTransformer, build it directly in Python,
# and train it on WMT16 en->ro with the ModelPruning callback enabled.
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelPruning
from transformers import AutoTokenizer

from lightning_transformers.core.nlp.huggingface.config import HFBackboneConfig
from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)
from lightning_transformers.task.nlp.translation.config import TranslationDataConfig


class MyTranslationTransformer(TranslationTransformer):
    """Translation task that always trains with AdamW at a fixed learning rate."""

    def configure_optimizers(self):
        # Single AdamW optimizer over every parameter; no LR scheduler is returned.
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        return optimizer


# The same pretrained checkpoint supplies both the tokenizer and the backbone.
BACKBONE_NAME = 't5-base'

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BACKBONE_NAME)
backbone_cfg = HFBackboneConfig(pretrained_model_name_or_path=BACKBONE_NAME)
# As reported, this constructor call currently raises
# TypeError: __init__() missing 1 required keyword-only argument: 'cfg'
model = MyTranslationTransformer('transformers.AutoModelForSeq2SeqLM', backbone_cfg)

data_cfg = TranslationDataConfig(source_language='en', target_language='ro')
datamodule = WMT16TranslationDataModule(cfg=data_cfg, tokenizer=tokenizer)

pruning = ModelPruning('l1_unstructured', use_lottery_ticket_hypothesis=True)
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=5, callbacks=[pruning])
trainer.fit(model, datamodule)
Instead, it crashes with:
Traceback (most recent call last):
  File "test.py", line 18, in <module>
    model = TranslationTransformer(
  File "/home/sean/lightning-transformers/lightning_transformers/task/nlp/translation/model.py", line 12, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required keyword-only argument: 'cfg'
🚀 Feature
I should be able to do this for all tasks:
Crashes: