AntonioCarta closed this issue 3 years ago.
I looked into this further. First, the callback implemented by GDumb has to be renamed to the new after_train_dataset_adaptation.
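Here is a plain-Python sketch of why the rename matters. This sketch assumes the plugin base class gives every hook a no-op default. Under that assumption, a plugin that still overrides the old method name is silently never called, because the strategy dispatches by the new name. This mirrors the loop at base_strategy.py line 400 in the traceback. PluginBase, ToyStrategy, StalePlugin, RenamedPlugin and old_adaptation_hook are all hypothetical placeholders. The thread does not name the old callback.

```python
class PluginBase:
    """Hypothetical base class: every hook is a no-op unless a subclass overrides it."""

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        pass


class ToyStrategy:
    """Hypothetical strategy that only knows the new hook name."""

    def __init__(self, plugins):
        self.plugins = plugins

    def after_train_dataset_adaptation(self, **kwargs):
        # Same dispatch pattern as `p.after_train_dataset_adaptation(self, **kwargs)`
        for p in self.plugins:
            p.after_train_dataset_adaptation(self, **kwargs)


class StalePlugin(PluginBase):
    def __init__(self):
        self.calls = 0

    def old_adaptation_hook(self, strategy, **kwargs):  # placeholder for the pre-rename name
        self.calls += 1


class RenamedPlugin(PluginBase):
    def __init__(self):
        self.calls = 0

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        self.calls += 1


stale, renamed = StalePlugin(), RenamedPlugin()
ToyStrategy([stale, renamed]).after_train_dataset_adaptation()
print(stale.calls, renamed.calls)  # prints "0 1"
```

No error is raised for the stale plugin. Its logic simply never runs, which is why the rename is easy to miss.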
However, there is still a problem when enumerate(strategy.experience.dataset) is called within the plugin. It raises RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32].
Since the dataset is not a DataLoader, what is the proper way to loop through it?
When you iterate over the dataset, you should get single examples rather than mini-batches. Apart from that, I think everything else should work.
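The difference can be sketched without PyTorch. ToyDataset is a hypothetical map-style dataset that supports only len() and integer indexing. Because it defines no __iter__, a plain enumerate() falls back to Python's sequence protocol: it calls __getitem__(0), __getitem__(1), and so on, until an IndexError is raised. The (pattern, target_value, task_label) triple mirrors the unpacking at gdumb.py line 41 in the traceback. iterate_batches is an illustrative stand-in for what a DataLoader with shuffle=False does, without collating into tensors.

```python
class ToyDataset:
    """Hypothetical map-style dataset: len() plus integer indexing, no __iter__."""

    def __init__(self, xs, ys, task_label=0):
        self.xs, self.ys, self.t = xs, ys, task_label

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        # Raising IndexError past the end is what lets plain for/enumerate stop.
        if not 0 <= idx < len(self.xs):
            raise IndexError(idx)
        return self.xs[idx], self.ys[idx], self.t


def iterate_examples(dataset):
    """Explicit index-based loop: one example per step."""
    for i in range(len(dataset)):
        x, y, t = dataset[i]
        yield i, x, y, t


def iterate_batches(dataset, batch_size):
    """Group consecutive examples into lists, roughly like an unshuffled DataLoader."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


ds = ToyDataset(["a", "b", "c"], [0, 1, 0])
for i, (pattern, target_value, _) in enumerate(ds):
    print(i, pattern, target_value)  # prints "0 a 0", then "1 b 1", then "2 c 0"
```

Each step of either loop indexes the dataset, so any error raised inside __getitem__ surfaces during iteration. This applies equally to an error from a transform applied to a single example.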
Yes, I expected that too, but I got that error instead.
Apparently, you cannot enumerate an experience.dataset object. Maybe @lrzpellegrini can help us with this.
That's strange; the dataset should be iterable. What error is raised when you enumerate it?
This is the error raised with GDumb (after renaming the callback) when using scenario = SplitFMnist(5). I noticed that it is not raised with SplitMNIST, though.
Traceback (most recent call last):
  File "/home/cossu/avalanche/examples/ewc_mnist.py", line 92, in <module>
    main(args)
  File "/home/cossu/avalanche/examples/ewc_mnist.py", line 63, in main
    strategy.train(experience)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 249, in train
    self.train_exp(exp, eval_streams, **kwargs)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 272, in train_exp
    self.after_train_dataset_adaptation(**kwargs)
  File "/home/cossu/avalanche/avalanche/training/strategies/base_strategy.py", line 400, in after_train_dataset_adaptation
    p.after_train_dataset_adaptation(self, **kwargs)
  File "/home/cossu/avalanche/avalanche/training/plugins/gdumb.py", line 41, in after_train_dataset_adaptation
    for i, (pattern, target_value, _) in enumerate(dataset):
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 1035, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 272, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 1035, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 207, in __getitem__
    result = super().__getitem__(idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 184, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 306, in __getitem__
    return TupleTLabel(manage_advanced_indexing(
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/dataset_utils.py", line 320, in manage_advanced_indexing
    single_element = single_element_getter(int(single_idx))
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 659, in _get_single_item
    return self._process_pattern(self._dataset[idx], idx)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 669, in _process_pattern
    pattern, label = self._apply_transforms(pattern, label)
  File "/home/cossu/avalanche/avalanche/benchmarks/utils/avalanche_dataset.py", line 680, in _apply_transforms
    pattern = self.transform(pattern)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 67, in __call__
    img = t(img)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 226, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/cossu/miniconda3/envs/avalanche-env/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 284, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]
Process finished with exit code 1
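The last frames of the traceback explain the message. torchvision's normalize reshapes a one-dimensional mean/std to shape (C, 1, 1) and then calls tensor.sub_(mean).div_(std) in place. Fashion-MNIST images have a single channel. Suppose the transform was configured with three per-channel values; this is an assumption, since the thread does not show the transform. Then the broadcast shape would be [3, 32, 32], while the in-place output is fixed at [1, 32, 32]. The pure-Python sketch below mimics the NumPy/PyTorch broadcasting rule and the in-place restriction. broadcast_shape and inplace_error are illustrative helpers, not part of either library.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes: align trailing dims; a dim of size 1 stretches to match."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def inplace_error(out_shape, other_shape):
    """An in-place op such as tensor.sub_(other) cannot grow `out`.

    Returns None if the op is allowed, else a message shaped like PyTorch's.
    """
    full = broadcast_shape(out_shape, other_shape)
    if full != tuple(out_shape):
        return (f"output with shape {list(out_shape)} "
                f"doesn't match the broadcast shape {list(full)}")
    return None


# Grayscale image vs. a 3-value mean reshaped to (3, 1, 1):
print(inplace_error((1, 32, 32), (3, 1, 1)))
# prints "output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]"
```

One way to avoid the error is to make the number of mean/std values match the channel count of the images, for example a single value each for grayscale data. This is consistent with the problem lying in the benchmark's transforms rather than in the GDumb plugin.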
OK, the last error has nothing to do with GDumb; it appears to be a bug in SplitFMnist. I will create a new issue to track it and close this one as soon as GDumb is ready.
GDumb does not remove samples when the number of classes increases.
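The expected behaviour can be sketched as class-balanced greedy storage in the spirit of GDumb. While the memory is not full, every sample is kept. Once it is full, a sample from an under-represented class replaces one from the currently largest class. A class counts as under-represented when it has fewer than capacity // number_of_seen_classes stored samples. Each new class therefore lowers the per-class quota and frees room. BalancedBuffer is a hypothetical helper, not Avalanche's GDumb plugin. The GDumb paper describes removing an instance from the largest class; this sketch pops the most recently stored one to stay deterministic.

```python
class BalancedBuffer:
    """Hypothetical fixed-capacity memory that keeps classes roughly balanced."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.per_class = {}  # label -> list of stored samples (dict keeps insertion order)

    def __len__(self):
        return sum(len(v) for v in self.per_class.values())

    def add(self, x, y):
        """Try to store sample x with label y; return True if it was stored."""
        bucket = self.per_class.setdefault(y, [])
        if len(self) < self.capacity:
            bucket.append(x)  # memory not full yet: always keep
            return True
        quota = self.capacity // len(self.per_class)  # shrinks as new classes appear
        if len(bucket) >= quota:
            return False  # this class already has its share
        # Memory is full and this class is below quota, so some class must be
        # above average; evict from the largest one to make room.
        largest = max(self.per_class, key=lambda c: len(self.per_class[c]))
        self.per_class[largest].pop()
        bucket.append(x)
        return True

    def counts(self):
        return {c: len(v) for c, v in self.per_class.items()}


buf = BalancedBuffer(capacity=4)
for i in range(4):
    buf.add(f"x{i}", 0)  # first experience: only class 0
buf.add("a", 1)
buf.add("b", 1)          # class 1 arrives: class 0 gives up two slots
print(buf.counts())      # prints "{0: 2, 1: 2}"
```

When a third class arrives, the quota drops to 4 // 3 = 1. The first class-2 sample then evicts one sample from class 0, the first of the tied largest classes.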