ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

GDumb memory update #457

Closed: AntonioCarta closed this issue 3 years ago.

AntonioCarta commented 3 years ago

GDumb does not remove samples from its memory when the number of classes increases.
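
The expected behaviour is the class-balanced greedy update described in the GDumb paper: once the memory is full, a sample from a class that is below its share (capacity divided by the number of classes seen so far) evicts a random sample from the currently largest class. The following is a toy, pure-Python sketch of that rule. The names ClassBalancedMemory, add, and class_counts are hypothetical, and this is not the code of Avalanche's GDumb plugin.

```python
import random
from collections import defaultdict


class ClassBalancedMemory:
    """Toy GDumb-style greedy buffer (hypothetical, not Avalanche's plugin)."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.buffer = defaultdict(list)  # class label -> stored patterns
        self.rng = random.Random(seed)

    def __len__(self):
        return sum(len(v) for v in self.buffer.values())

    def add(self, pattern, label):
        if len(self) < self.capacity:
            # Memory not full yet: always store the sample.
            self.buffer[label].append(pattern)
            return
        # Per-class share, counting the incoming class as seen.
        seen = set(self.buffer) | {label}
        quota = self.capacity // len(seen)
        if len(self.buffer[label]) >= quota:
            return  # this class already holds its share; drop the sample
        # Make room by evicting a random sample of the largest class.
        largest = max(self.buffer, key=lambda c: len(self.buffer[c]))
        victims = self.buffer[largest]
        victims.pop(self.rng.randrange(len(victims)))
        self.buffer[label].append(pattern)

    def class_counts(self):
        return {c: len(v) for c, v in self.buffer.items() if v}


mem = ClassBalancedMemory(capacity=6)
for i in range(6):
    mem.add(f"x{i}", label=0)  # first experience: only class 0
print(mem.class_counts())  # {0: 6}
for i in range(6):
    mem.add(f"y{i}", label=1)  # a new class arrives: class 0 shrinks
print(mem.class_counts())  # {0: 3, 1: 3}
```

Without the eviction step, class 0 would keep all six slots and class 1 would never enter the memory, which is the behaviour reported in this issue.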

AndreaCossu commented 3 years ago

I looked into this further. First, the callback implemented by GDumb has to be renamed to the new after_train_dataset_adaptation hook. However, there is still a problem when enumerate(strategy.experience.dataset) is called inside the plugin: it raises RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32].

Since the dataset is not a DataLoader, what is the proper way to loop through it?
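
The rename matters because the strategy dispatches plugin callbacks by method name; the traceback further down shows strategy.after_train_dataset_adaptation calling p.after_train_dataset_adaptation(self, **kwargs) on each plugin. Below is a toy model of that dispatch. ToyPluginBase, ToyStrategy, and old_hook_name are hypothetical stand-ins rather than Avalanche classes, and the base class is assumed to provide no-op defaults, so a plugin that still uses an outdated method name is silently skipped.

```python
class ToyPluginBase:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        pass


class ToyStrategy:
    def __init__(self, plugins):
        self.plugins = plugins

    def after_train_dataset_adaptation(self, **kwargs):
        # Same call pattern as in the traceback.
        for p in self.plugins:
            p.after_train_dataset_adaptation(self, **kwargs)


class OldNamePlugin(ToyPluginBase):
    def __init__(self):
        self.calls = 0

    def old_hook_name(self, strategy, **kwargs):  # hypothetical pre-rename name
        self.calls += 1


class RenamedPlugin(ToyPluginBase):
    def __init__(self):
        self.calls = 0

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        self.calls += 1


old, new = OldNamePlugin(), RenamedPlugin()
ToyStrategy([old, new]).after_train_dataset_adaptation()
print(old.calls, new.calls)  # 0 1
```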

AntonioCarta commented 3 years ago

When you iterate over the dataset, you should get single examples instead of mini-batches. Apart from that, I think everything else should work.
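
To illustrate the difference in plain Python: iterating a map-style dataset yields one (pattern, target, task_label) triplet per step, matching the unpacking used in gdumb.py, while a DataLoader groups examples into mini-batches. ToyDataset and iter_batches are hypothetical helpers, and iter_batches only groups examples into lists; it does not collate them into tensors as a real DataLoader would.

```python
class ToyDataset:
    """Hypothetical map-style dataset returning (pattern, target, task_label)."""

    def __init__(self, patterns, targets, task_label=0):
        self.patterns = patterns
        self.targets = targets
        self.task_label = task_label

    def __len__(self):
        return len(self.patterns)

    def __getitem__(self, idx):
        return self.patterns[idx], self.targets[idx], self.task_label


def iter_batches(dataset, batch_size):
    """Group single examples into lists of at most batch_size items."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # last, possibly smaller, batch


ds = ToyDataset(["a", "b", "c", "d", "e"], [0, 1, 0, 1, 2])

# One example per step, unpacked as in the GDumb plugin.
singles = [(i, pattern, target) for i, (pattern, target, _) in enumerate(ds)]
print(singles[0])  # (0, 'a', 0)

batches = list(iter_batches(ds, batch_size=2))
print([len(b) for b in batches])  # [2, 2, 1]
```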

AndreaCossu commented 3 years ago

Yes, I expected that too, but I got that error instead. Apparently, an experience.dataset object cannot be enumerated. Maybe @lrzpellegrini can help us with this.

lrzpellegrini commented 3 years ago

That's strange; the dataset should be an iterable object. What error is raised when you enumerate it?
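
In the traceback below, enumerate(dataset) goes straight into __getitem__. That is what Python does for an object that defines __getitem__ but not __iter__: it calls __getitem__(0), __getitem__(1), ... and stops at the first IndexError. Any other exception raised while loading an item, such as a RuntimeError from a transform, propagates out of the loop unchanged, so that kind of failure says nothing about iterability (a non-iterable object raises TypeError instead). A pure-Python sketch, where SequenceOnlyDataset and bad_index are hypothetical:

```python
class SequenceOnlyDataset:
    """Defines __len__ and __getitem__ but no __iter__."""

    def __init__(self, items, bad_index=None):
        self.items = items
        self.bad_index = bad_index  # hypothetical: index whose loading fails

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if idx == self.bad_index:
            # Stand-in for a transform failing on one pattern.
            raise RuntimeError(f"transform failed on item {idx}")
        return self.items[idx]  # IndexError past the end stops iteration


# Python falls back to __getitem__(0), __getitem__(1), ... until IndexError.
print(list(enumerate(SequenceOnlyDataset(["x", "y"]))))  # [(0, 'x'), (1, 'y')]

# Any other exception escapes the loop.
aborted_with = None
try:
    for i, item in enumerate(SequenceOnlyDataset(["x", "y", "z"], bad_index=1)):
        pass
except RuntimeError as err:
    aborted_with = str(err)
print(aborted_with)  # transform failed on item 1
```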

AndreaCossu commented 3 years ago

This is the error raised with GDumb (after renaming the callback) when using scenario = SplitFMnist(5). I noticed that it is not raised with SplitMNIST, though.

Traceback (most recent call last):
  File "/home/cossu/avalanche/examples/ewc_mnist.py", line 92, in <module>
    main(args)
  File "/home/cossu/avalanche/examples/ewc_mnist.py", line 63, in main
    strategy.train(experience)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 249, in train
    self.train_exp(exp, eval_streams, **kwargs)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 272, in train_exp
    self.after_train_dataset_adaptation(**kwargs)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 400, in after_train_dataset_adaptation
    p.after_train_dataset_adaptation(self, **kwargs)
  File "/home/cossu/avalanche/avalanche/training/plugins/gdumb.py", line 41, in after_train_dataset_adaptation
    for i, (pattern, target_value, _) in enumerate(dataset):
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 1035, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 272, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 1035, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 207, in __getitem__
    result = super().__getitem__(idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 184, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 669, in _process_pattern
    pattern, label = self._apply_transforms(pattern, label)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 680, in _apply_transforms
    pattern = self.transform(pattern)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 67, in __call__
    img = t(img)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 226, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 284, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]

Process finished with exit code 1

AndreaCossu commented 3 years ago

OK, the last error has nothing to do with GDumb; it appears to be a bug in SplitFMnist. I will open a new issue to track it and close this one as soon as GDumb is ready.
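
The last frames of the traceback can be reproduced in miniature. In recent torchvision versions, normalize reshapes a 1-D mean to [C, 1, 1] and then runs tensor.sub_(mean) in place; an in-place operation cannot grow its output, so a 3-value mean applied to a single-channel 32x32 image fails. The sketch below re-implements the NumPy/PyTorch broadcasting rule on plain shape lists. The helper names are hypothetical, and the 3-channel statistics are an assumption about the SplitFMnist transforms that was not checked against its code.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes, aligning dimensions from the right."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(da if db == 1 else db)
    return list(reversed(result))


def check_inplace(out_shape, other_shape):
    """An in-place op such as tensor.sub_(other) must keep the output shape."""
    target = broadcast_shape(out_shape, other_shape)
    if target != list(out_shape):
        raise RuntimeError(
            f"output with shape {list(out_shape)} doesn't match "
            f"the broadcast shape {target}"
        )


image = [1, 32, 32]   # assumed: single-channel image resized to 32x32
mean_3ch = [3, 1, 1]  # assumed: 3 per-channel means viewed as [C, 1, 1]
mean_1ch = [1, 1, 1]  # one mean for a single-channel image

check_inplace(image, mean_1ch)  # OK: broadcast shape is [1, 32, 32]

message = None
try:
    check_inplace(image, mean_3ch)
except RuntimeError as err:
    message = str(err)
print(message)
# output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]
```

Under this assumption, the failure would go away either by normalizing with single-channel statistics or by expanding the images to three channels before normalization.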