qdrant / quaterion

Blazing fast framework for fine-tuning similarity learning models
https://quaterion.qdrant.tech/
Apache License 2.0

PyTorch Lightning's Stochastic Weight Averaging callback causes infinite recursion #164

Closed andrewaf1 closed 2 years ago

andrewaf1 commented 2 years ago

When trying to use SWA to improve generalization for my triplet network, I came across the following error. I believe SWA creates copies of the model, and I can see how that might trigger something like this. I have recently used SWA on a non-Quaterion network, so I am pretty sure this is not a Lightning issue.

Traceback (most recent call last):
  File "/home/andrew/vake/siamese-classifier/train.py", line 155, in <module>
    train(
  File "/home/andrew/vake/siamese-classifier/train.py", line 85, in train
    Quaterion.fit(
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/quaterion/main.py", line 101, in fit
    trainer.fit(
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1174, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1493, in _call_setup_hook
    self._call_callback_hooks("setup", stage=fn)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py", line 142, in setup
    self._average_model = deepcopy(pl_module)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 205, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/andrew/anaconda3/envs/vake/lib/python3.9/copy.py", line 271, in _reconstruct
    if hasattr(y, '__setstate__'):
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/quaterion/eval/attached_metric.py", line 48, in __getattr__
    return getattr(self._metric, item)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/quaterion/eval/attached_metric.py", line 48, in __getattr__
    return getattr(self._metric, item)
  File "/home/andrew/.cache/pypoetry/virtualenvs/imageclassifier-iYVta-lH-py3.9/lib/python3.9/site-packages/quaterion/eval/attached_metric.py", line 48, in __getattr__
    return getattr(self._metric, item)
  [Previous line repeated 978 more times]
RecursionError: maximum recursion depth exceeded

joein commented 2 years ago

Hi @andrewaf1! Yes, it is a bug on our side, sorry for the inconvenience. I have already opened a PR to fix this, and we will merge it soon.
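
For readers following the traceback, here is a minimal sketch of how this kind of recursion can arise. It does not reproduce Quaterion's code or the actual fix: `Metric`, `BrokenWrapper` and `SaferWrapper` are hypothetical names, and the guard shown is just one common way to make a delegating `__getattr__` safe to copy. The mechanism itself is standard Python behaviour: `copy.deepcopy` first creates an empty instance and then probes it with `hasattr(y, '__setstate__')`. On that empty instance `_metric` does not exist yet, so looking up `self._metric` inside `__getattr__` calls `__getattr__` again, without end.

```python
import copy


class Metric:
    """Hypothetical stand-in for a wrapped metric object."""

    def compute(self):
        return 0.5


class BrokenWrapper:
    """Delegates unknown attributes to the wrapped metric, with no guard."""

    def __init__(self, metric):
        self._metric = metric

    def __getattr__(self, item):
        # If `_metric` is not set yet (as on the empty instance deepcopy
        # creates), `self._metric` re-enters __getattr__ and recurses.
        return getattr(self._metric, item)


class SaferWrapper:
    """Same delegation, but refuses `_metric` itself and dunder lookups."""

    def __init__(self, metric):
        self._metric = metric

    def __getattr__(self, item):
        is_dunder = item.startswith("__") and item.endswith("__")
        if item == "_metric" or is_dunder:
            # Assumption for this sketch: special methods are never delegated.
            raise AttributeError(item)
        return getattr(self._metric, item)


# The unguarded wrapper fails in the same way as the traceback above.
try:
    copy.deepcopy(BrokenWrapper(Metric()))
    broken_raised = False
except RecursionError:
    broken_raised = True

# The guarded wrapper can be deep-copied and still delegates afterwards.
original = SaferWrapper(Metric())
clone = copy.deepcopy(original)
print(broken_raised, clone.compute())
```

Raising `AttributeError` for the missing attribute lets `hasattr` return `False`, so `deepcopy` falls back to restoring the instance `__dict__` directly.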

joein commented 2 years ago

@andrewaf1 try v0.1.29 please

andrewaf1 commented 2 years ago

@joein Thanks! It now trains without error, although when I load my model for inference I get vastly reduced accuracy. I will examine this further tomorrow or Monday to determine whether I am doing something wrong.

joein commented 2 years ago

@andrewaf1

The problem may lie in the amount of data on which you calculate your metrics.

If you look at the results of AttachedMetric and then compare them to the metrics you get across the entire dataset, there is likely to be a significant difference.

This is because AttachedMetric only sees a fixed-size batch, whereas at inference time the search runs over the entire dataset, which also contains more difficult examples.
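
To make the batch-versus-dataset effect concrete, here is a small toy sketch in plain Python. It does not use the Quaterion API: `precision_at_1` is a hypothetical helper, the embeddings are one-dimensional integers, and the data is deliberately extreme so that the effect is easy to see.

```python
def precision_at_1(points):
    """Fraction of points whose nearest other point has the same label."""
    hits = 0
    for i, (x, label) in enumerate(points):
        # Distance from this point to every other point in the candidate set.
        others = [
            (abs(x - y), other_label)
            for j, (y, other_label) in enumerate(points)
            if j != i
        ]
        _, nearest_label = min(others)
        hits += nearest_label == label
    return hits / len(points)


# Two batches of (embedding, label) pairs. Inside each batch the classes
# are far apart, but each class has a close neighbour in the other batch.
batch_1 = [(0, "a"), (4, "a"), (30, "b"), (34, "b")]
batch_2 = [(1, "c"), (5, "c"), (31, "d"), (35, "d")]

# Metric computed per batch, as a batch-level metric would see it.
batch_scores = [precision_at_1(batch_1), precision_at_1(batch_2)]

# Metric computed over all points at once, as at inference time.
full_score = precision_at_1(batch_1 + batch_2)

print(batch_scores)  # [1.0, 1.0]
print(full_score)    # 0.0
```

Every query looks perfect inside its own batch because the hard negatives live in the other batch. Once all points are searched together, those negatives become the nearest neighbours and the score collapses.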

I'd recommend using Evaluator to get a better sense of the metrics you will obtain at inference.

An example of Evaluator usage can be found in our Q&A demo.

This issue seems to be fixed, so I am going to close it soon. Please feel free to open a new issue or a discussion if you find anything odd or need help.