Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Allow calling fit multiple times with Quantization (QAT) #7427

Closed: t-vi closed this 3 years ago

t-vi commented 3 years ago

🐛 Bug

When using the QuantizationAwareTraining() callback, one cannot call Trainer.fit twice.

To Reproduce

Calling Trainer.fit twice on any model raises an exception (which is pretty opaque, I might add).

Expected behavior

The second call resumes training.

Additional context

I imagine there is a tradeoff between two goals:

  1. Have fit return a model that's ready for inference,
  2. don't catapult yourself irreversibly out of the option to continue training.

With Quantization Aware Training's conversion step at the end (moving from fake quantization for QAT to quantized layers), we have to choose one. Currently, the QAT hook converts, so it has 1 but not 2.

To my mind, the conversion step is "preparing for export / inference" rather than part of training, so I would suggest dropping it from fitting.
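
To make the proposed split concrete, here is a minimal sketch in plain eager-mode PyTorch (not Lightning's callback code), assuming a recent PyTorch where these functions live in `torch.ao.quantization` (older releases expose them as `torch.quantization`). `TinyNet` and `train_some` are hypothetical stand-ins for the model and for one `fit` call. Training can be repeated as long as the model is only fake-quantized; `convert` is applied once, as an export step.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the network inside a LightningModule."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized boundary
        self.layer = nn.Linear(32, 2)
        self.dequant = DeQuantStub()  # quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.layer(self.quant(x)))


def train_some(model: nn.Module, steps: int = 5) -> None:
    """Hypothetical stand-in for one fit() call: a few SGD steps on random data."""
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        loss = model(torch.randn(8, 32)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()


torch.manual_seed(0)
model = TinyNet()  # freshly built modules are in training mode, as prepare_qat expects
model.qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
qat_model = prepare_qat(model)  # transition 1: float -> QAT (fake quantization), returns a copy

train_some(qat_model)  # first "fit"
train_some(qat_model)  # second "fit" still works: the model is only fake-quantized

# Export step, kept out of training. Transition 2: QAT -> quantized (irreversible).
qat_model.eval()
quantized = convert(qat_model)  # returns a converted copy; qat_model stays trainable
print(type(quantized.layer).__name__, len(list(quantized.parameters())))
```

After `convert`, `quantized.layer` is a quantized `Linear` that exposes no trainable parameters, which is exactly what a later `fit` on the converted model would trip over.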

t-vi commented 3 years ago

To avoid confusion, calling fit twice means something like:

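import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining

# BoringModel and BoringDataModule are defined earlier in the notebook (not shown)
dm = BoringDataModule()
model = BoringModel()
t = pl.Trainer()  # unused; it only produces the first "GPU available ... used: False" lines below
trainer = pl.Trainer(gpus=1, max_epochs=1, check_val_every_n_epoch=1, callbacks=[QuantizationAwareTraining()])
trainer.fit(model, datamodule=dm)  # QAT, then conversion to a quantized model at the end of fit
trainer.fit(model, datamodule=dm)  # the second call fails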

With BoringModel, I get:

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type        | Params
----------------------------------------
0 | layer   | Linear      | 66    
1 | quant   | QuantStub   | 0     
2 | dequant | DeQuantStub | 0     
----------------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
0.000     Total estimated model params size (MB)

Epoch 0:  50%|█████     | 64/128 [00:00<00:00, 389.96it/s, loss=0.197, v_num=40]
Validating: 0it [00:00, ?it/s]
Epoch 0: 100%|██████████| 128/128 [00:00<00:00, 546.09it/s, loss=0.197, v_num=40]
Epoch 0: 100%|██████████| 128/128 [00:00<00:00, 540.57it/s, loss=0.197, v_num=40]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-47-27d928f7fc20> in <module>
    152 trainer = pl.Trainer(gpus=1, max_epochs=1, check_val_every_n_epoch=1, callbacks=[QuantizationAwareTraining()])
    153 trainer.fit(model, datamodule=dm)
--> 154 trainer.fit(model, datamodule=dm)

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    456         )
    457 
--> 458         self._run(model)
    459 
    460         assert self.state.stopped

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    713         self.call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
    714         self.call_configure_sharded_model(model)  # allow user to setup in model sharded environment
--> 715         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
    716 
    717         # ----------------------------

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/accelerators/gpu.py in setup(self, trainer, model)
     39         self.set_nvidia_flags()
     40         torch.cuda.set_device(self.root_device)
---> 41         return super().setup(trainer, model)
     42 
     43     def on_train_start(self) -> None:

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/accelerators/accelerator.py in setup(self, trainer, model)
     90         self.setup_training_type_plugin(self.training_type_plugin, model)
     91         if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
---> 92             self.setup_optimizers(trainer)
     93         self.setup_precision_plugin(self.precision_plugin)
     94 

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/accelerators/accelerator.py in setup_optimizers(self, trainer)
    372         if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
    373             return
--> 374         optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
    375             trainer=trainer, model=self.lightning_module
    376         )

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in init_optimizers(self, trainer, model)
    188 
    189     def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule'):
--> 190         return trainer.init_optimizers(model)
    191 
    192     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/optimizers.py in init_optimizers(self, model)
     32     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
     33         self._lightning_optimizers = None
---> 34         optim_conf = model.configure_optimizers()
     35         if optim_conf is None:
     36             rank_zero_warn(

<ipython-input-47-27d928f7fc20> in configure_optimizers(self)
     98 
     99     def configure_optimizers(self):
--> 100         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    101         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    102         return [optimizer], [lr_scheduler]

/usr/local/lib/python3.9/dist-packages/torch/optim/sgd.py in __init__(self, params, lr, momentum, dampening, weight_decay, nesterov)
     67         if nesterov and (momentum <= 0 or dampening != 0):
     68             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
---> 69         super(SGD, self).__init__(params, defaults)
     70 
     71     def __setstate__(self, state):

/usr/local/lib/python3.9/dist-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
     47         param_groups = list(params)
     48         if len(param_groups) == 0:
---> 49             raise ValueError("optimizer got an empty parameter list")
     50         if not isinstance(param_groups[0], dict):
     51             param_groups = [{'params': param_groups}]

ValueError: optimizer got an empty parameter list
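
The `ValueError` follows from what conversion does to `self.layer`: the quantized `Linear` stores packed int8 weights rather than `nn.Parameter`s, so `self.layer.parameters()` is empty by the time `configure_optimizers` runs again. A minimal plain-PyTorch reproduction, assuming a recent PyTorch (`torch.ao.quantization`); the `Boring` class is a hypothetical look-alike of BoringModel after the callback has added its stubs:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class Boring(nn.Module):
    """Hypothetical module shaped like BoringModel plus the callback's quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.layer = nn.Linear(32, 2)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.layer(self.quant(x)))


model = Boring()
model.qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
qat = prepare_qat(model)
qat(torch.randn(8, 32))          # one forward pass so the observers record some statistics
quantized = convert(qat.eval())  # what the first fit() leaves behind

# What configure_optimizers() does during the second fit(), now on the converted layer:
n_params = len(list(quantized.layer.parameters()))
try:
    torch.optim.SGD(quantized.layer.parameters(), lr=0.1)
    error = None
except ValueError as exc:
    error = str(exc)
print(n_params, error)
```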

If you use a more interesting model that still has parameters after quantization (the M5 model from the PyTorch Audio keyword detection example), you get something like:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-62-a5d40beafe36> in <module>
      2 trainer = pl.Trainer(gpus=1, max_epochs=1, check_val_every_n_epoch=1, callbacks=[QuantizationAwareTraining()])
      3 trainer.fit(model, datamodule=dm)
----> 4 trainer.fit(model, datamodule=dm)

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    456         )
    457 
--> 458         self._run(model)
    459 
    460         assert self.state.stopped

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    754 
    755         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 756         self.dispatch()
    757 
    758         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    795             self.accelerator.start_predicting(self)
    796         else:
--> 797             self.accelerator.start_training(self)
    798 
    799     def run_stage(self):

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    805         if self.predicting:
    806             return self.run_predict()
--> 807         return self.run_train()
    808 
    809     def _pre_training_routine(self):

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    840             self.progress_bar_callback.disable()
    841 
--> 842         self.run_sanity_check(self.lightning_module)
    843 
    844         self.checkpoint_connector.has_trained = False

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
   1105 
   1106             # run eval step
-> 1107             self.run_evaluation()
   1108 
   1109             self.on_sanity_check_end()

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, on_epoch)
    960                 # lightning module methods
    961                 with self.profiler.profile("evaluation_step_and_end"):
--> 962                     output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
    963                     output = self.evaluation_loop.evaluation_step_end(output)
    964 

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/evaluation_loop.py in evaluation_step(self, batch, batch_idx, dataloader_idx)
    172             model_ref._current_fx_name = "validation_step"
    173             with self.trainer.profiler.profile("validation_step"):
--> 174                 output = self.trainer.accelerator.validation_step(args)
    175 
    176         # capture any logged information

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, args)
    224 
    225         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
--> 226             return self.training_type_plugin.validation_step(*args)
    227 
    228     def test_step(self, args: List[Union[Any, int]]) -> Optional[STEP_OUTPUT]:

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
    159 
    160     def validation_step(self, *args, **kwargs):
--> 161         return self.lightning_module.validation_step(*args, **kwargs)
    162 
    163     def test_step(self, *args, **kwargs):

<ipython-input-61-234454f3c7b4> in validation_step(self, batch, batch_idx)
     57     def validation_step(self, batch, batch_idx):
     58         inp, label = batch
---> 59         pred = self(inp)
     60         loss = torch.nn.functional.cross_entropy(pred, label)
     61         acc = self.val_accuracy(pred.softmax(dim=-1), label)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1013         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1014                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1015             return forward_call(*input, **kwargs)
   1016         # Do not call functions when jit is used
   1017         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/quantization.py in wrapper(data)
     49             quant_cb._forward_calls += 1
     50             data = model.quant(data)
---> 51         data = func(data)
     52         # apply custom trigger
     53         if _quant_run:

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/quantization.py in wrapper(data)
     66     def wrapper(data) -> Any:
     67         data = model.quant(data)
---> 68         data = func(data)
     69         data = model.dequant(data)
     70         return data

<ipython-input-61-234454f3c7b4> in forward(self, x)
     45 
     46     def forward(self, x):
---> 47         return self.model(x)
     48 
     49     def training_step(self, batch, batch_idx):

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1013         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1014                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1015             return forward_call(*input, **kwargs)
   1016         # Do not call functions when jit is used
   1017         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-61-234454f3c7b4> in forward(self, x)
     17 
     18     def forward(self, x):
---> 19         x = self.conv1(x)
     20         x = F.relu(self.bn1(x))
     21         x = self.pool1(x)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1013         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1014                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1015             return forward_call(*input, **kwargs)
   1016         # Do not call functions when jit is used
   1017         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/quantized/modules/conv.py in forward(self, input)
    328             input = F.pad(input, _reversed_padding_repeated_twice,
    329                           mode=self.padding_mode)
--> 330         return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
    331 
    332     @classmethod

NotImplementedError: Could not run 'quantized::conv1d' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv1d' is only available for these backends: [QuantizedCPU, BackendSelect, Named, InplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at ../aten/src/ATen/native/quantized/cpu/qconv.cpp:876 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
InplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:60 [backend fallback]
AutogradOther: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
Tracer: fallthrough registered at ../torch/csrc/jit/frontend/tracer.cpp:1027 [backend fallback]
Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
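
The second error has a different cause: the converted model now contains quantized kernels, and as the message says, `quantized::conv1d` is registered for `QuantizedCPU` only, while the second `fit` runs on the GPU. A sketch of training the QAT model on any device but converting and running it on the CPU, assuming a recent PyTorch (`torch.ao.quantization`); `TinyConv` is a hypothetical one-layer network, not the actual M5:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class TinyConv(nn.Module):
    """Hypothetical 1-D conv net, loosely in the spirit of M5 (not the actual model)."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv1d(1, 4, kernel_size=8, stride=4)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv1(self.quant(x)))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyConv()
model.qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
qat = prepare_qat(model).to(device)
qat(torch.randn(2, 1, 64, device=device))  # stand-in for QAT training, possibly on the GPU

# Quantized kernels such as quantized::conv1d only exist for QuantizedCPU,
# so move the model to the CPU before converting it and feed it CPU tensors.
quantized = convert(qat.cpu().eval())
out = quantized(torch.randn(2, 1, 64))
print(out.shape)
```
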

edenlightning commented 3 years ago

@Borda mind taking a look?

Borda commented 3 years ago

@t-vi How about adding an extra argument to the callback that skips this quantization in on_fit_end? In that case, the user would need to call it separately, whenever they want...
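
A sketch of such an opt-out in plain PyTorch, assuming a recent PyTorch (`torch.ao.quantization`). `QATHelper`, its `convert_on_fit_end` flag, `quantize_model`, `Net`, and the toy `fit` function are all hypothetical names for illustration, not the callback's actual API. With the flag off, `fit` can be called repeatedly, and conversion becomes a separate step the user triggers once training is done.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class QATHelper:
    """Hypothetical helper (not Lightning's API) with an opt-out for conversion at fit end."""

    def __init__(self, convert_on_fit_end: bool = True):
        self.convert_on_fit_end = convert_on_fit_end  # hypothetical flag name
        self.prepared = False

    def on_fit_start(self, model: nn.Module) -> None:
        # Swap in fake-quantized modules only once; later fits reuse them.
        if not self.prepared:
            model.train()  # prepare_qat expects a model in training mode
            model.qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
            prepare_qat(model, inplace=True)
            self.prepared = True

    def on_fit_end(self, model: nn.Module) -> None:
        if self.convert_on_fit_end:
            self.quantize_model(model)

    def quantize_model(self, model: nn.Module) -> None:
        # The irreversible QAT -> quantized step, called whenever the user is done training.
        model.eval()
        convert(model, inplace=True)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.layer = nn.Linear(16, 2)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.layer(self.quant(x)))


def fit(model: nn.Module, helper: QATHelper, steps: int = 3) -> None:
    """Toy stand-in for Trainer.fit that calls the helper's hooks."""
    helper.on_fit_start(model)
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)  # would fail on an already converted model
    for _ in range(steps):
        loss = model(torch.randn(8, 16)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    helper.on_fit_end(model)


helper = QATHelper(convert_on_fit_end=False)
model = Net()
fit(model, helper)
fit(model, helper)            # still trainable: the model is only fake-quantized
helper.quantize_model(model)  # explicit export step, done once at the end
```

The helper prepares the model only on the first `on_fit_start`, so later `fit` calls continue from the existing fake-quantization modules and their observer state instead of starting QAT over.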

t-vi commented 3 years ago

So there are two practically irreversible transitions: Model -> QAT-Model -> Quantized model. Yes, it would make sense to have something to skip the second one and offer it as an extra step. It's really a bit sad that you cannot unquantize a model, but it seems that the PyTorch quantization didn't really appreciate the philosophy of PyTorch as much as I would have wished.

Borda commented 3 years ago

> it seems that the PyTorch quantization didn't really appreciate the philosophy of PyTorch as much as I would have wished.

A kind of workaround could be to replace the quantized modules with the original ones and fill in the values from the quantized ones... What do you think, is it worth doing this kind of reverse engineering?

t-vi commented 3 years ago

I think this reversal could be done (you'd have rounding errors, but hey), but the problem I see with doing it outside PyTorch is operator coverage.
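
Borda's reversal idea can be sketched for a single module type, assuming a recent PyTorch where quantized modules live in `torch.ao.nn.quantized`. `unquantize_linear` is a hypothetical helper: it rebuilds a float `nn.Linear` from the dequantized int8 weight and the float bias, so the rounding error mentioned above stays baked into the restored weights. Every other quantized operator would need a rule of its own, which is where operator coverage comes in.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq


def unquantize_linear(qlinear: nnq.Linear) -> nn.Linear:
    """Hypothetical helper: rebuild a float nn.Linear from a quantized Linear.

    Only covers this one module type; other quantized ops would need their own rules.
    """
    weight = qlinear.weight().dequantize()  # int8 weight back to float, rounding error included
    bias = qlinear.bias()                   # the bias is kept in float
    restored = nn.Linear(qlinear.in_features, qlinear.out_features, bias=bias is not None)
    with torch.no_grad():
        restored.weight.copy_(weight)
        if bias is not None:
            restored.bias.copy_(bias)
    return restored


torch.manual_seed(0)
original = nn.Linear(4, 3)
scale = 0.01  # hypothetical per-tensor scale; nn.Linear(4, 3) initializes weights in [-0.5, 0.5]
qweight = torch.quantize_per_tensor(original.weight.detach(), scale, 0, torch.qint8)
qlinear = nnq.Linear(4, 3)
qlinear.set_weight_bias(qweight, original.bias.detach())

restored = unquantize_linear(qlinear)
max_err = (restored.weight - original.weight).abs().max().item()
print(max_err)  # bounded by scale / 2 for weights inside the representable range
```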

Borda commented 3 years ago

> the problem of doing it outside PyTorch that I see is operator coverage

What do you mean by operator coverage?

t-vi commented 3 years ago

I mean that if there are quantized operators we don't handle (correctly), then we can't unquantize the model. I'd have to think a bit about how much of a problem I'd expect it to be; it might not be as bad as I thought. (Maybe I just wrote too much about operator coverage in those quantization tutorials... 🙂)
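
One way to gauge how much the coverage gap matters for a given model is to list the quantized submodules that a reversal pass would not know how to handle. A minimal sketch, assuming a recent PyTorch (`torch.ao.nn.quantized`); `HANDLED`, `unhandled_quantized_modules`, and `Mixed` are hypothetical, and detecting "quantized" modules by their Python module path is a heuristic, not an official API.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq

# Hypothetical: the quantized module types an "unquantize" pass knows how to reverse.
HANDLED = (nnq.Linear, nnq.Quantize, nnq.DeQuantize)


def unhandled_quantized_modules(model: nn.Module, prefix: str = "") -> list:
    """Return (name, type name) for quantized submodules a reversal pass would not cover."""
    missing = []
    for name, child in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, HANDLED):
            continue  # covered as a whole; don't descend into its packed params
        if ".quantized" in type(child).__module__:  # heuristic check for quantized modules
            missing.append((full_name, type(child).__name__))
        else:
            missing.extend(unhandled_quantized_modules(child, full_name))
    return missing


class Mixed(nn.Module):
    """Hypothetical already-converted model; no forward() is needed for the check."""

    def __init__(self):
        super().__init__()
        self.quant = nnq.Quantize(0.1, 0, torch.quint8)
        self.conv1 = nnq.Conv1d(1, 4, kernel_size=3)
        self.fc = nnq.Linear(4, 2)
        self.dequant = nnq.DeQuantize()


report = unhandled_quantized_modules(Mixed())
print(report)
```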