JulesBelveze / bert-squeeze

🛠️ Tools for Transformers compression using PyTorch Lightning ⚡
https://julesbelveze.github.io/bert-squeeze/

Seq2Seq distillation `InterpolationKeyError` #66

Closed. JulesBelveze closed this issue 4 months ago.

JulesBelveze commented 4 months ago

As reported by @tarekziade, the following code leads to a Hydra `InterpolationKeyError`:

from bert_squeeze.assistants import DistilAssistant
from lightning.pytorch import Trainer

config_assistant = {
    # Teacher and student start from the same seq2seq checkpoint;
    # the student only overrides the number of decoder layers.
    "teacher_kwargs": {
        "pretrained_model": "cnicu/t5-small-booksum",
    },
    "student_kwargs": {
        "pretrained_model": "cnicu/t5-small-booksum",
        "num_decoder_layers": 6,
    },
    # Summarization dataset: "chapter" is the source text, "summary_text" the target.
    "data_kwargs": {
        "teacher_module": {
            "dataset_config": {
                "path": "kmfoda/booksum",
                "target_col": "summary_text",
                "source_col": "chapter",
            }
        }
    },
    # Pruning and dynamic-quantization callbacks, referenced by their "_target_" paths.
    "callbacks": [
        {
            "_target_": "bert_squeeze.utils.callbacks.pruning.ThresholdBasedPruning",
            "threshold": 0.2,
            "start_pruning_epoch": -1,
        },
        {"_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization"},
    ],
}

assistant = DistilAssistant("distil-seq2seq", **config_assistant)

model = assistant.model
callbacks = assistant.callbacks

train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()

# A two-step run is enough to trigger the error.
basic_trainer = Trainer(max_steps=2, callbacks=callbacks)

basic_trainer.fit(
    model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader
)
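
For context, Hydra resolves values like `${some.key}` through OmegaConf, and an `InterpolationKeyError` is raised when an interpolation points at a key that does not exist in the composed config. Below is a minimal, bert-squeeze-independent sketch of that failure mode; the `teacher.num_layers` / `student.num_layers` keys are made up for illustration and are not the actual offending keys, which would need to be identified from the full traceback:

from omegaconf import OmegaConf
from omegaconf.errors import InterpolationKeyError

# "student.num_layers" interpolates "teacher.num_layers", which is never
# defined, so resolving the value raises InterpolationKeyError.
cfg = OmegaConf.create({"student": {"num_layers": "${teacher.num_layers}"}})

try:
    _ = cfg.student.num_layers
except InterpolationKeyError as err:
    print(f"InterpolationKeyError: {err}")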
