GT-RIPL / Continual-Learning-Benchmark

Evaluate three types of task shifting with popular continual learning algorithms.
MIT License

GEM uses memory data in active training set and updates memory twice #1

Closed srcdc closed 4 years ago

srcdc commented 5 years ago

The GEM implementation here inherits from Naive_Rehearsal and calls super(GEM, self).learn_batch(train_loader, val_loader). It therefore uses Naive_Rehearsal's learn_batch method, which mixes memory data with the new data when computing the original gradients, before any conflicts with the memory gradients are checked. From my understanding, that is not the intent of the original paper and might affect the training results.

Additionally, Naive_Rehearsal's learn_batch method already updates the memory and the task_count, but GEM does this a second time once the call returns.
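
For context, here is a rough, self-contained sketch of the call structure I mean. The class names mirror the repo's, but the bodies are illustrative stubs rather than the actual code:

class NormalNN:
    def learn_batch(self, train_loader, val_loader=None):
        pass  # stand-in for the actual training loop

class Naive_Rehearsal(NormalNN):
    def __init__(self):
        self.task_count = 0
        self.memory = []

    def learn_batch(self, train_loader, val_loader=None):
        combined = self.memory + list(train_loader)             # memory mixed into the new data
        super(Naive_Rehearsal, self).learn_batch(combined, val_loader)
        self.task_count += 1                                    # first task_count / memory update
        self.memory += list(train_loader)

class GEM(Naive_Rehearsal):
    def learn_batch(self, train_loader, val_loader=None):
        super(GEM, self).learn_batch(train_loader, val_loader)  # gradient step sees memory + new data
        self.task_count += 1                                    # second, duplicated update
        self.memory += list(train_loader)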

yenchanghsu commented 5 years ago

Thanks for pointing out the differences! You are right that the GEM class should not inherit from Naive_Rehearsal in its current form. I will fix this and check how it impacts the performance.

shivamsaboo17 commented 4 years ago

Hi, I was curious to know whether this issue has been fixed yet?

srcdc commented 4 years ago

As far as I can tell neither the code nor the results were updated.

shivamsaboo17 commented 4 years ago

I see. So I made a simple change: I added a boolean argument to Naive_Rehearsal's learn_batch function and have GEM call it with correction=True, which skips the rehearsal data mixing. This also prevents the repeated update of the memory. Is this implementation correct, or have I missed something?

def learn_batch(self, train_loader, val_loader=None, correction=False):
    if correction:
        # Skip Naive_Rehearsal's data mixing and memory update; train on the raw loader only
        super(Naive_Rehearsal, self).learn_batch(train_loader, val_loader)
        return
    # ... original Naive_Rehearsal logic: mix memory into the loader, train, then update memory ...
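
If that looks right, GEM's learn_batch would then call it with correction=True and keep a single memory update of its own. A sketch of the GEM side (update_memory is just a placeholder name for whatever the repo's memory-update code does):

class GEM(Naive_Rehearsal):
    def learn_batch(self, train_loader, val_loader=None):
        # correction=True skips Naive_Rehearsal's data mixing and memory update
        super(GEM, self).learn_batch(train_loader, val_loader, correction=True)
        self.task_count += 1
        self.update_memory(train_loader)  # placeholder for GEM's own memory update
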
yenchanghsu commented 4 years ago

Sorry for the late update. The duplicated memory operation in GEM has been removed. Its performance is the same. In the case of split-MNIST-incremental-class, the average accuracy is 92.07±0.32 (new) versus 92.20±0.12 (previous result).

I also added GEM_orig to match the procedure in the original GEM paper, which does not include the memory data when computing the new gradients. GEM_orig is slightly lower than our implementation (91.32±0.11 versus 92.07±0.32).
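
Conceptually, the only difference between the two variants is which data feeds the "current" gradient before GEM's projection step. A minimal sketch of that difference, with illustrative helpers rather than the repo's code (the full GEM solves a QP over all past-task constraints; the projection below handles only a single constraint):

import torch

def flat_grad(model, loss):
    # Gradient of `loss` w.r.t. all trainable parameters, flattened into one vector
    grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
    return torch.cat([g.reshape(-1) for g in grads])

def current_gradient(model, criterion, new_x, new_y, mem_x, mem_y, include_memory):
    # GEM_orig: gradient from the new batch only; our GEM variant also mixes in memory data
    if include_memory:
        x, y = torch.cat([new_x, mem_x]), torch.cat([new_y, mem_y])
    else:
        x, y = new_x, new_y
    return flat_grad(model, criterion(model(x), y))

def project_single_constraint(g, g_mem):
    # GEM-style correction for one memory gradient: if g conflicts with g_mem
    # (negative dot product), remove the conflicting component.
    dot = torch.dot(g, g_mem)
    if dot < 0:
        g = g - (dot / g_mem.pow(2).sum()) * g_mem
    return g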

srcdc commented 4 years ago

Great. Thanks for the update!