aimagelab / mammoth

An Extendible (General) Continual Learning Framework based on Pytorch - official codebase of Dark Experience for General Continual Learning

Question about the transform in buffer.py #8

Closed: bhattg closed this issue 3 years ago

bhattg commented 3 years ago

Hello,

I have a question about der.py and buffer.py, specifically about the transforms applied for data augmentation. The following is the transform used in DER for Split CIFAR-10:

Compose(
    ToPILImage()
    Compose(
        RandomCrop(size=(32, 32), padding=4)
        RandomHorizontalFlip(p=0.5)
        ToTensor()
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2615))
    )
)

When we store samples in the buffer, we always save the non-augmented inputs and the corresponding logits, as shown in the following snippet:

self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

    def add_data(self, examples, labels=None, logits=None, task_labels=None):
        # Lazily allocate the buffer tensors on the first call.
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels, logits, task_labels)

        for i in range(examples.shape[0]):
            # Reservoir sampling: returns a slot to overwrite, or -1 to discard.
            index = reservoir(self.num_seen_examples, self.buffer_size)
            self.num_seen_examples += 1
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                if logits is not None:
                    self.logits[index] = logits[i].to(self.device)
                if task_labels is not None:
                    self.task_labels[index] = task_labels[i].to(self.device)
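
For context, a minimal sketch of what the `reservoir` call above does (written from the standard reservoir-sampling algorithm, not copied from the repo):

    import numpy as np

    def reservoir(num_seen_examples: int, buffer_size: int) -> int:
        """Return the buffer slot to overwrite, or -1 to discard the sample."""
        if num_seen_examples < buffer_size:
            # Buffer not yet full: append at the next free slot.
            return num_seen_examples
        # Otherwise keep the new sample with probability
        # buffer_size / (num_seen_examples + 1), evicting a random slot.
        rand = np.random.randint(0, num_seen_examples + 1)
        return rand if rand < buffer_size else -1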

Now, when we call get_all_elements or get_data:

    def get_data(self, size: int, transform: transforms=None) -> Tuple:
        """
        Randomly samples a batch of `size` items.
        :param size: the number of requested items
        :param transform: the transformation to be applied (data augmentation)
        :return: a tuple of the transformed examples plus the stored attributes
        """
        if size > min(self.num_seen_examples, self.examples.shape[0]):
            size = min(self.num_seen_examples, self.examples.shape[0])

        choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                  size=size, replace=False)
        if transform is None: transform = lambda x: x
        # The (stochastic) transform is applied here, at retrieval time.
        ret_tuple = (torch.stack([transform(ee.cpu())
                            for ee in self.examples[choice]]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)

        return ret_tuple

It applies the transform to the non-augmented examples.

Now my question is the following: when we request elements from the buffer, will the transformation applied on top of the non-augmented images still be the same one that generated the corresponding logits? Since the transforms are stochastic (random crop/flip), the retrieved example will generally differ from the originally transformed input that produced the logits.
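
A minimal, self-contained sketch of what I mean (a toy example written for illustration, not code from the repo):

    import torch
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(size=(32, 32), padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ])

    x = torch.rand(3, 32, 32)   # stands in for a stored non-augmented image
    view_a = transform(x)       # e.g. the view that produced the stored logits
    view_b = transform(x)       # e.g. the view later returned by get_data
    print(torch.equal(view_a, view_b))  # almost surely False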

It would be great if you could clarify this, thanks!

mbosc commented 3 years ago

Yes, this is correct. As we point out in Section 4.1 of our paper, this is a deliberate choice that amounts to applying an additional consistency regularisation objective. It is key to getting better performance, not only for DER but also for all other rehearsal-based methods (especially ER).
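
In code terms, the replay step looks roughly like this (a sketch with assumed names such as `net`, `buffer`, `alpha`, and `minibatch_size`; the stored logits come from one augmentation, while the network sees a fresh one drawn by `get_data`):

    import torch.nn.functional as F

    def der_replay_loss(net, buffer, minibatch_size, transform, alpha):
        # buf_logits were recorded under one random augmentation at insertion
        # time; buf_inputs carry a new, independently drawn augmentation.
        buf_inputs, buf_logits = buffer.get_data(minibatch_size, transform=transform)
        return alpha * F.mse_loss(net(buf_inputs), buf_logits)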

bhattg commented 3 years ago

Thanks a lot for the lightning-fast response. I will have a look at the papers you cite in Section 4.1. Although I haven't read them yet, what you said makes a lot of intuitive sense. However, when writing down the loss term mathematically with the following notation:

Let f(x, \theta) be the network producing logits in R^d, let x' denote the transformed (augmented) input, and let \theta' denote the parameters after some further gradient steps. The loss being applied can then be written as

Err(x, x', \theta, \theta') = || f(x', \theta') - f(x, \theta) ||

By the triangle inequality (inserting f(x, \theta') as an intermediate point), this can be upper bounded:
Err(x, x', \theta, \theta') \leq || f(x', \theta') - f(x, \theta') || + || f(x, \theta') - f(x, \theta) || = (consistency loss at \theta') + (dark experience distillation of logits)
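
Written out step by step, the bound is just the triangle inequality with f(x, \theta') made explicit:

    \begin{aligned}
    \mathrm{Err}(x, x', \theta, \theta')
      &= \lVert f(x', \theta') - f(x, \theta) \rVert \\
      &= \lVert f(x', \theta') - f(x, \theta') + f(x, \theta') - f(x, \theta) \rVert \\
      &\leq \underbrace{\lVert f(x', \theta') - f(x, \theta') \rVert}_{\text{consistency at } \theta'}
       + \underbrace{\lVert f(x, \theta') - f(x, \theta) \rVert}_{\text{DER logit distillation}}
    \end{aligned}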

Therefore, is there a way of mathematically showing that minimizing Err will also lead to the minimization of the sum on the RHS?

Once again, thanks for the quick response!

mbosc commented 3 years ago

I'm not sure I understand the question correctly: are you asking how one can show that minimizing the first expression (what we actually apply) also minimizes the sum of the consistency and DER losses that upper-bounds it?

I do not really think that we should see the combination of the two losses as a straightforward sum, but rather as a kind of composition: we are computing DER loss on top of the application of consistency regularisation.

Since we assume that the model should learn to disregard augmentations when producing a response, the DER loss is all the more valid when distinct augmentations are at play.

bhattg commented 3 years ago

Thanks for the response. I might not have conveyed my question properly. Quoting from Section 4.1:

"It is worth noting that combining data augmentation with our regularization objective enforces an implicit consistency loss, which aligns predictions for the same example subjected to small data transformations."

The first expression, Err(x, x', \theta, \theta') = || f(x', \theta') - f(x, \theta) ||, is what gets optimized during training (in addition to the cross-entropy loss on the current examples), while the RHS is only an upper bound on || f(x', \theta') - f(x, \theta) ||. That is,

Err(x, x', \theta, \theta') \leq || f(x', \theta') - f(x, \theta') || + || f(x, \theta') - f(x, \theta) ||

"Since we are assuming that the model should learn to disregard augmentations when producing a response"

My question is from a mathematical point of view. While in practice the model does learn to disregard the augmentation, is there a way to show that optimizing Err (in addition to the cross-entropy loss on the current task) also leads to the minimization of the consistency loss?