Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`.forward()` got an unexpected keyword argument when using `QuantizationAwareTraining` #15862

Closed · bilelomrani1 closed this issue 1 year ago

bilelomrani1 commented 1 year ago

Bug description

When the .forward method takes multiple input tensors and the `QuantizationAwareTraining` callback is used, the inputs are not handled properly by the callback. This may be related to #8677, which is currently unsolved.

If I understand this code correctly, restrictive assumptions are made about the inputs of the .forward method: it must take a single input tensor, which is quantized before the forward pass. This is generally too restrictive: with the HuggingFace transformers library, for instance, the tokenizer produces multiple tensors that must all be passed as inputs to the Transformer backbone. Moreover, these input tensors are embedding indices and attention masks (LongTensors), so as far as I understand they must not be quantized before the forward pass; I wonder how that interacts with what the callback does here.
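
For illustration, here is a minimal sketch (not the actual Lightning source) of the kind of wrapper the callback installs around forward; the single-tensor assumption is what breaks with keyword arguments:

import functools

from torch import Tensor

def wrap_qat_forward_context(model):
    forward = model.forward

    @functools.wraps(forward)
    def wrapper(data: Tensor) -> Tensor:
        # The wrapper accepts one positional tensor, so calling it with
        # keyword arguments such as `input_ids=...` raises a TypeError.
        data = model.quant(data)    # fake-quantize the single input tensor
        data = forward(data)
        return model.dequant(data)  # dequantize the output

    return wrapper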

How to reproduce the bug

import os
from typing import Dict, List

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import QuantizationAwareTraining
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

class RandomDataset(Dataset):
    def __init__(self):
        self.data = [
            "When he had to picnic on the beach, he purposely put sand in other people’s food.",
            "I love bacon, beer, birds, and baboons.",
            "Improve your goldfish's physical fitness by getting him a bicycle.",
            "The knives were out and she was sharpening hers.",
            "For some unfathomable reason, the response team didn't consider a lack of milk for my cereal as a proper emergency.",
            "Chocolate covered crickets were his favorite snack.",
            "He drank life before spitting it out.",
            "The snow-covered path was no help in finding his way out of the back-country.",
        ]

    def __getitem__(self, index: int) -> str:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)

class DataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        self.tokenizer = tokenizer

    def __call__(self, samples: List[str]) -> BatchEncoding:
        return self.tokenizer(samples, return_tensors="pt", padding=True)

class BoringModel(pl.LightningModule):
    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        self.transformers_module = AutoModelForSequenceClassification.from_config(config)

    def forward(self, **batch: BatchEncoding) -> Tensor:
        return self.transformers_module.forward(**batch).logits

    def training_step(self, batch: BatchEncoding, _: int) -> Dict[str, Tensor]:
        loss = self.forward(**batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch: BatchEncoding, _: int) -> None:
        loss = self.forward(**batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch: BatchEncoding, _: int) -> None:
        loss = self.forward(**batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.transformers_module.parameters(), lr=0.001)

def run():
    model_name = "roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    data_collator = DataCollator(tokenizer)
    train_data = DataLoader(RandomDataset(), batch_size=2, collate_fn=data_collator, shuffle=True)
    val_data = DataLoader(RandomDataset(), batch_size=2, collate_fn=data_collator, shuffle=False)
    test_data = DataLoader(RandomDataset(), batch_size=2, collate_fn=data_collator, shuffle=False)

    model = BoringModel(config=AutoConfig.from_pretrained(model_name))
    trainer = pl.Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=[QuantizationAwareTraining()],  # commenting this out makes the script run
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

The previous code runs correctly when the callback is commented out. We get the same exception with this slight variation of the forward method:

    def forward(self, batch: BatchEncoding) -> Tensor:
        return self.transformers_module.forward(**batch).logits

Error messages and logs

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1558: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                    | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 92, in <module>
    run()
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 87, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 582, in fit
    call._call_and_handle_interrupt(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 624, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _run_stage
    self._run_train()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1163, in _run_train
    self.fit_loop.run()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 214, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 247, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 357, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1305, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1661, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 121, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/sgd.py", line 130, in step
    loss = closure()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 107, in _wrap_closure
    closure_result = closure()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
    step_output = self._step_fn()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 52, in training_step
    loss = self.forward(**batch).sum()
TypeError: BoringModel.forward() got an unexpected keyword argument 'input_ids'
Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s]  

Environment


* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning-utilities: 0.3.0
        - pytorch-lightning: 1.8.3.post1
        - torch:             1.13.0
        - torchmetrics:      0.10.3
* Packages:
        - aiohttp:           3.8.3
        - aiosignal:         1.3.1
        - astroid:           2.12.13
        - async-timeout:     4.0.2
        - attrs:             22.1.0
        - black:             22.10.0
        - certifi:           2022.9.24
        - charset-normalizer: 2.1.1
        - click:             8.1.3
        - dill:              0.3.6
        - filelock:          3.8.0
        - fire:              0.4.0
        - frozenlist:        1.3.3
        - fsspec:            2022.11.0
        - huggingface-hub:   0.11.1
        - idna:              3.4
        - isort:             5.10.1
        - lazy-object-proxy: 1.8.0
        - lightning-utilities: 0.3.0
        - mccabe:            0.7.0
        - multidict:         6.0.2
        - mypy:              0.991
        - mypy-extensions:   0.4.3
        - numpy:             1.23.5
        - packaging:         21.3
        - pathspec:          0.10.2
        - pip:               22.3.1
        - platformdirs:      2.5.4
        - protobuf:          3.20.1
        - pylint:            2.15.7
        - pyparsing:         3.0.9
        - pytorch-lightning: 1.8.3.post1
        - pyyaml:            6.0
        - regex:             2022.10.31
        - requests:          2.28.1
        - setuptools:        63.2.0
        - six:               1.16.0
        - tensorboardx:      2.5.1
        - termcolor:         2.1.1
        - tokenizers:        0.13.2
        - tomli:             2.0.1
        - tomlkit:           0.11.6
        - torch:             1.13.0
        - torchmetrics:      0.10.3
        - tqdm:              4.64.1
        - transformers:      4.24.0
        - typing-extensions: 4.4.0
        - urllib3:           1.26.13
        - wrapt:             1.14.1
        - yarl:              1.8.1
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
                - 
        - processor:         i386
        - python:            3.10.8
        - version:           Darwin Kernel Version 22.1.0: Sun Oct  9 20:14:54 PDT 2022; root:xnu-8792.41.9~2/RELEASE_X86_64

cc @borda

Borda commented 1 year ago

@bilelomrani1 does your example work fine without QAT? I would pass just the batch instead of `**batch`.

bilelomrani1 commented 1 year ago

@Borda I went with

    def forward(self, batch: BatchEncoding) -> Tensor:
        return self.transformers_module.forward(**batch).logits

I still get an exception (albeit a different one); the trace is the following:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/ao/quantization/observer.py:214: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1558: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|                                                                           | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 248, in __getattr__
    return self.data[item]
KeyError: 'detach'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 92, in <module>
    run()
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 87, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 582, in fit
    call._call_and_handle_interrupt(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 624, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _run_stage
    self._run_train()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1163, in _run_train
    self.fit_loop.run()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 214, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 247, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 357, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1305, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1661, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 121, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/optim/sgd.py", line 130, in step
    loss = closure()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 107, in _wrap_closure
    closure_result = closure()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
    step_output = self._step_fn()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/Users/bilelomrani/Documents/ILLUIN.nosync/mre-lightning-qat/mre.py", line 52, in training_step
    loss = self.forward(batch).sum()
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/pytorch_lightning/callbacks/quantization.py", line 61, in wrapper
    data = model.quant(data)  # type: ignore[operator]
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1211, in _call_impl
    hook_result = hook(self, input, result)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/ao/quantization/quantize.py", line 117, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/torch/ao/quantization/fake_quantize.py", line 160, in forward
    self.activation_post_process(X.detach())
  File "/Users/bilelomrani/.pyenv/versions/mre-lightning/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 250, in __getattr__
    raise AttributeError
AttributeError

I confirm that the code runs successfully when the callback is commented out.
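
For context, a minimal illustration of why this variant fails: BatchEncoding is a dict-like container whose attribute lookup falls through to the underlying data dict, so the observer's call to X.detach() on the batch raises AttributeError:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
batch = tokenizer(["hello world"], return_tensors="pt")

batch["input_ids"].detach()  # fine: the values stored in the batch are tensors
batch.detach()               # AttributeError: __getattr__ falls through to batch.data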

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

bilelomrani1 commented 1 year ago

The issue is still relevant.

awaelchli commented 1 year ago

FYI, we removed the QAT callback in #16750. As explained in the linked PR:

The QAT callback can no longer be maintained by us. It has many issues that make the callback ineffective, and these can't be fixed at the moment.

Users who rely on this callback: you can stay on the Lightning 1.9.x version, which gets long-term support (LTS), OR you can copy the callback code and maintain it yourself.

If someone from the community is interested in fixing and maintaining this callback, please let us know.

Since this issue is relatively new, we will keep this one open. We might be able to address this and bring the fix to 1.9.x LTS. @bilelomrani1 Do you have interest in contributing a fix?

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

bilelomrani1 commented 1 year ago

Hi @awaelchli sorry for the delay on this topic. I ended up going with an external library for training-aware compression (NNCF by Intel, which fits my needs). I am no longer using native torch for this.

I interfaced NNCF with Lightning through a callback (not super clean yet, but it gets the job done). It's quite simple, but if there is any interest in this, I would be glad to contribute back and submit a PR. In any case, it's fine to close this issue; it is no longer needed.

clementpoiret commented 10 months ago

@bilelomrani1 sorry for this bump. I'd be interested in this callback, and possibly in maintaining it too if needed. Just a small question: why OpenVINO NNCF and not Intel Neural Compressor?

bilelomrani1 commented 10 months ago

Hi @clementpoiret, here is the callback

import logging
from typing import Any, Dict, Mapping, Optional, cast

import nncf
import pytorch_lightning as pl
import torch
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from pytorch_lightning.utilities.types import STEP_OUTPUT

class NncfCallback(pl.Callback):
    logger = logging.getLogger(__name__)

    def __init__(self, config: Mapping) -> None:
        super().__init__()
        nncf.NNCFConfig.validate(config)
        self.config = nncf.NNCFConfig(config)

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.logger.info("Initializing NNCF compression algorithm")
        # Register the dataloader and criterion NNCF needs for its
        # initialization passes (e.g. quantizer range estimation).
        compression_config = nncf.torch.register_default_init_args(
            nncf_config=self.config,
            train_loader=trainer.datamodule.nncf_initializing_dataloader(),  # type: ignore[attr-defined]
            criterion=pl_module.criterion,
        )
        # Replace the module's model with NNCF's compression-wrapped version
        # and keep a handle on the compression controller.
        pl_module.compression_controller, pl_module.model = nncf.torch.create_compressed_model(
            model=pl_module.model,
            config=compression_config,
            dump_graphs=False,
        )
        if torch.distributed.is_initialized():
            cast(PTCompressionAlgorithmController, pl_module.compression_controller).distributed()

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        # Add the compression objective to the loss
        assert isinstance(outputs, dict)
        compression_loss = cast(PTCompressionAlgorithmController, pl_module.compression_controller).loss()
        pl_module.log("train/base_loss", outputs["loss"])
        outputs["loss"] += compression_loss
        outputs["compression_loss"] = compression_loss
        pl_module.log("train/compression_loss", compression_loss)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Advance the compression scheduler once per optimization step.
        cast(PTCompressionAlgorithmController, pl_module.compression_controller).scheduler.step()

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Advance the epoch-based part of the compression schedule.
        cast(PTCompressionAlgorithmController, pl_module.compression_controller).scheduler.epoch_step()

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        assert isinstance(outputs, dict)
        compression_loss = cast(PTCompressionAlgorithmController, pl_module.compression_controller).loss()
        pl_module.log("val/base_loss", outputs["loss"])
        outputs["loss"] += compression_loss
        outputs["compression_loss"] = compression_loss
        pl_module.log("val/compression_loss", compression_loss)

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        # Persist the NNCF config and compression state so the compressed
        # model can be rebuilt when restoring from the checkpoint.
        checkpoint["nncf_config"] = self.config
        if hasattr(pl_module, "compression_controller"):
            checkpoint["compression_state"] = cast(
                PTCompressionAlgorithmController, pl_module.compression_controller
            ).get_compression_state()
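
For reference, a hypothetical usage sketch (the config values are illustrative, not from my original setup; the LightningModule is assumed to expose .model and .criterion attributes, as the callback requires):

import pytorch_lightning as pl

# Minimal NNCF quantization config following the standard NNCF JSON schema;
# the sample size is a placeholder for the model's actual input shape.
nncf_config = {
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "compression": {"algorithm": "quantization"},
}
trainer = pl.Trainer(callbacks=[NncfCallback(nncf_config)])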

The DataModule should have a `.nncf_initializing_dataloader()` method that returns an `nncf.torch.initialization.PTInitializingDataLoader`.
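
A minimal sketch of that hook (hypothetical names; it assumes batches are (inputs, targets) pairs and nncf==2.4.0, as in the snippet above):

import pytorch_lightning as pl
from nncf.torch.initialization import PTInitializingDataLoader

class MyInitializingDataLoader(PTInitializingDataLoader):
    # Tells NNCF how to unpack a batch during its initialization passes.
    def get_inputs(self, dataloader_output):
        inputs, _ = dataloader_output  # assumes (inputs, targets) batches
        return (inputs,), {}

    def get_target(self, dataloader_output):
        _, targets = dataloader_output
        return targets

class MyDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        ...  # your regular training DataLoader

    def nncf_initializing_dataloader(self) -> PTInitializingDataLoader:
        return MyInitializingDataLoader(self.train_dataloader())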

I wrote it a while ago with nncf==2.4.0, so some things may have changed since then. I don't remember exactly why I chose this specific framework; I was looking for a training-aware quantization implementation. Intel Neural Compressor may have such features, but I don't remember testing it, so it is perhaps worth a look. Do you have an opinion on the difference between the two?

clementpoiret commented 10 months ago

Thanks for the callback @bilelomrani1! I feel that NNCF integrates well with the OpenVINO ecosystem and its specific optimizations, but Intel Neural Compressor might be more generic. I also think Intel Neural Compressor has more SotA methods implemented for post-training quantization; for quantization-aware training, it relies on PyTorch's native QAT implementation.

clementpoiret commented 10 months ago

@bilelomrani1 I made very simple callbacks for Intel Neural Compressor, if you want to try them :) https://github.com/clementpoiret/lightning-nc