Elastic weight consolidation technique for incremental learning.
Use this API if you don't want your neural network to forget previously learned tasks while doing transfer learning or domain adaptation.
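At a high level, EWC adds a quadratic penalty that anchors the parameters important to previous tasks (importance estimated via the Fisher information) to their old values. The sketch below only illustrates that idea and is not this package's internal implementation; `fisher`, `old_params`, and `ewc_weight` are placeholder names.

```python
import torch

def ewc_penalty(model, fisher, old_params, ewc_weight):
    """Illustrative EWC penalty: ewc_weight * sum_i F_i * (theta_i - theta*_i)^2.

    `fisher` and `old_params` are dicts keyed by parameter name, captured
    after training on the previous task.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()
    return ewc_weight * penalty

# Total loss on the new task = task loss + EWC penalty, e.g.:
# loss = criterion(model(input), target) + ewc_penalty(model, fisher, old_params, ewc_weight)
```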
The experiment trains on MNIST and then on Fashion-MNIST, and reports accuracy (%) on each dataset after both tasks, with and without EWC:
| EWC | MNIST | Fashion-MNIST |
|-----|-------|---------------|
| Yes | 70.27 | 81.88 |
| No  | 48.43 | 86.69 |
```python
from elastic_weight_consolidation import ElasticWeightConsolidation

# Build a neural network of your choice and a PyTorch dataset/dataloader for it.
# Define a loss criterion for the new task and pass it as shown below.
ewc = ElasticWeightConsolidation(model, crit, lr=0.01, weight=0.1)

# Training procedure: one combined forward, backward, and update step per batch.
for input, target in dataloader:
    ewc.forward_backward_update(input, target)

# After finishing a task, register the consolidated parameters by sampling
# `num_batches_to_run_for_sampling` batches of size `batch_size` from its dataset.
ewc.register_ewc_params(dataset, batch_size, num_batches_to_run_for_sampling)

# Repeat this for each new task and its corresponding dataset.
```
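A minimal end-to-end sketch for two sequential tasks (MNIST, then Fashion-MNIST) is shown below. The network architecture, hyperparameters, and torchvision datasets are illustrative assumptions; only the `ElasticWeightConsolidation` calls mirror the snippet above.

```python
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from elastic_weight_consolidation import ElasticWeightConsolidation

transform = transforms.ToTensor()
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
fashion = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)

# A small MLP used purely for illustration.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 100), nn.ReLU(), nn.Linear(100, 10))
ewc = ElasticWeightConsolidation(model, nn.CrossEntropyLoss(), lr=0.01, weight=0.1)

for task_dataset in (mnist, fashion):
    loader = DataLoader(task_dataset, batch_size=100, shuffle=True)
    # Train on the current task.
    for input, target in loader:
        ewc.forward_backward_update(input, target)
    # Consolidate: record parameter importance before moving to the next task.
    ewc.register_ewc_params(task_dataset, 100, 300)
```

Evaluating the model on each task's held-out test set after this loop is how a comparison like the table above would be produced.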