h3lio5 / episodic-lifelong-learning

Implementation of "Episodic Memory in Lifelong Language Learning" (NeurIPS 2019) in PyTorch
MIT License
catastrophic-forgetting continual-learning deep-learning experience-replay lifelong-learning natural-language-processing nips-2019 nips-paper pytorch

Episodic Memory in Lifelong Language Learning

Code for the paper Episodic Memory in Lifelong Language Learning (arXiv:1905.12926), for the text classification setup.

Introduction

The ability to continuously learn and accumulate knowledge throughout a lifetime, and to reuse it effectively to adapt to new problems quickly, is a hallmark of general intelligence. State-of-the-art machine learning models work well on a single dataset given enough training examples, but they often fail to isolate and reuse previously acquired knowledge when the data distribution shifts (e.g., when presented with a new dataset), a phenomenon known as catastrophic forgetting. In the paper, the authors introduce a lifelong language learning setup in which a model must learn from a stream of text examples without any dataset identifier. Specifically, they propose an episodic memory model that performs sparse experience replay and local adaptation to mitigate catastrophic forgetting in this setup. Experiments on text classification and question answering tasks demonstrate that the episodic memory module is a crucial building block of general linguistic intelligence.

Model

The main components of the model (MbPA++, implemented in models/MbPAplusplus.py) are:

1. Example encoder: a pretrained BERT encoder that maps the input text to a vector representation.
2. Task decoder: a classification head that predicts a label from that representation.
3. Episodic memory module: a key-value store of previously seen examples, used for sparse experience replay during training and for local adaptation at inference. A minimal sketch of such a memory follows.
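
A minimal sketch of how such a key-value episodic memory might look (the class and method names here are illustrative, not the repo's actual API; in the paper, keys are frozen-encoder representations of the inputs and the stored values are the examples themselves):

import random
import torch


class EpisodicMemory:
    # Sketch of a key-value episodic memory. Keys are encoder
    # representations of the inputs; the stored values are the
    # (text, label) pairs themselves so they can be replayed later.

    def __init__(self):
        self.keys = []      # 1-D tensors: one key per stored example
        self.examples = []  # (text, label) pairs, aligned with self.keys

    def write(self, key, example):
        # The paper writes every example it sees; no eviction policy here.
        self.keys.append(key.detach())
        self.examples.append(example)

    def sample(self, n):
        # Uniform random draw, used for sparse experience replay.
        idx = random.sample(range(len(self.examples)), min(n, len(self.examples)))
        return [self.examples[i] for i in idx]

    def nearest(self, query_key, k):
        # K nearest neighbours by Euclidean distance, used for local
        # adaptation at inference time.
        keys = torch.stack(self.keys)                      # (N, d)
        dists = torch.cdist(query_key.unsqueeze(0), keys)  # (1, N)
        idx = torch.topk(dists.squeeze(0), k=min(k, len(self.keys)),
                         largest=False).indices
        return [self.examples[i] for i in idx.tolist()]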

1. Setup Instructions and Dependencies

You may set up the repository on your local machine either by downloading it or by running the following command in a terminal:

git clone https://github.com/h3lio5/episodic-lifelong-learning.git

All dependencies required by this repo can be installed by creating a virtual environment with Python 3.7 and running:

python3 -m venv .env
source .env/bin/activate
pip install -r requirements.txt

2. Directory Description

Download the data manually from the link, or run the following to download it automatically:

python data_download.py

The dataset files that you need to download and extract are the train and test CSVs for AGNews, Amazon, DBpedia, Yahoo, and Yelp (see the directory tree below).

Rename the train and test CSV files according to their corresponding dataset names and place them in the original_data subdirectory of the data/ directory. For instance, place the training set of Amazon under data/original_data/train/ as amazon.csv.
The repository should look like this after downloading and placing the data in the appropriate folders:

root
├── README.md  
├── data
│   ├── ordered_data
│   │   ├── test
│   │   └── train
│   └── original_data
│       ├── test
│       │   ├── agnews.csv
│       │   ├── amazon.csv
│       │   ├── dbpedia.csv
│       │   ├── yahoo.csv
│       │   └── yelp.csv
│       └── train
│           ├── agnews.csv
│           ├── amazon.csv
│           ├── dbpedia.csv
│           ├── yahoo.csv
│           └── yelp.csv
├── data_loader.py
├── main.py
├── models
│   ├── MbPAplusplus.py
│   └── baselines
│       ├── MbPA.py
│       ├── enc_dec.py
│       └── replay.py
├── preprocess.py
└── requirements.txt

3. Preprocessing

To preprocess the data and create the ordered datasets, run

python preprocess.py
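
Conceptually, this step concatenates the five datasets into a single training stream in a fixed order. A hypothetical sketch of what that amounts to (the actual orderings, column handling, and output filenames are defined in preprocess.py; the order and paths below are purely illustrative):

import pandas as pd

# One illustrative dataset ordering; preprocess.py defines the real ones.
ORDER = ["yelp", "agnews", "dbpedia", "amazon", "yahoo"]

# Concatenate the per-dataset training CSVs into a single ordered stream.
frames = [pd.read_csv(f"data/original_data/train/{name}.csv") for name in ORDER]
stream = pd.concat(frames, ignore_index=True)
stream.to_csv("data/ordered_data/train/ordered.csv", index=False)  # hypothetical output path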

4. Training Model from Scratch

To train your own model from scratch, run

python main.py --mode train --epochs "any_number" --order "1/2/3/4"
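
Here --epochs sets the number of training epochs and --order picks one of the four dataset orderings. Training follows the paper's sparse experience replay scheme: every incoming example is written to episodic memory, and at a fixed interval a random batch is drawn back out of memory for an extra gradient step. A self-contained toy sketch of that loop structure (random feature vectors stand in for encoded text; every name and hyperparameter value here is illustrative, not the repo's actual code):

import random
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins so the sketch runs end to end: a frozen "key encoder"
# and a trainable linear classifier over random features instead of a
# BERT encoder over real text.
key_encoder = nn.Linear(16, 16).requires_grad_(False)
model = nn.Linear(16, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

memory = []            # (key, features, label) triples, one per example
replay_interval = 100  # sparse: one replay step per 100 stream steps
replay_batch = 32

for step in range(1000):
    x = torch.randn(1, 16)         # one example from the stream
    y = torch.randint(0, 5, (1,))

    # Ordinary supervised step on the incoming example.
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    optimizer.step()

    # Write every example to episodic memory.
    memory.append((key_encoder(x).detach(), x, y))

    # Sparse experience replay: occasionally revisit random old examples.
    if step and step % replay_interval == 0:
        batch = random.sample(memory, min(replay_batch, len(memory)))
        xs = torch.cat([b[1] for b in batch])
        ys = torch.cat([b[2] for b in batch])
        optimizer.zero_grad()
        criterion(model(xs), ys).backward()
        optimizer.step()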

5. Inference

To test the model, run

python main.py --mode test --model_path "path_to_checkpoint" --memory_path "path_to_replay_memory"
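
Inference uses the paper's local adaptation: for each test example, the K nearest examples are retrieved from the replay memory by key similarity, a throwaway copy of the trained model is fine-tuned on them for a few gradient steps with an L2 penalty that keeps it close to the trained weights, and the prediction is made with that adapted copy. A self-contained toy sketch (raw feature vectors double as inputs and memory keys here, whereas the real model uses frozen-encoder keys; all names and hyperparameter values are illustrative):

import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins so the sketch runs: a "trained" classifier and a filled
# memory of 500 examples. Raw feature vectors double as the memory keys.
model = nn.Linear(16, 5)
memory_x = torch.randn(500, 16)
memory_y = torch.randint(0, 5, (500,))

def locally_adapt_and_predict(x, k=32, steps=5, lam=1e-3, lr=1e-2):
    # 1. Retrieve the K stored examples nearest to the query.
    dists = torch.cdist(x, memory_x).squeeze(0)
    idx = torch.topk(dists, k, largest=False).indices

    # 2. Fine-tune a throwaway copy of the model on those neighbours,
    #    with an L2 penalty keeping it close to the trained weights.
    adapted = copy.deepcopy(model)
    base = [p.detach().clone() for p in model.parameters()]
    opt = torch.optim.SGD(adapted.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss()
    for _ in range(steps):
        loss = crit(adapted(memory_x[idx]), memory_y[idx])
        loss = loss + lam * sum(((p - b) ** 2).sum()
                                for p, b in zip(adapted.parameters(), base))
        opt.zero_grad()
        loss.backward()
        opt.step()

    # 3. Predict with the adapted copy; the adaptation is then discarded.
    return adapted(x).argmax(dim=-1)

print(locally_adapt_and_predict(torch.randn(1, 16)))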

References

1. Cyprien de Masson d'Autume, Sebastian Ruder, Lingpeng Kong, and Dani Yogatama. Episodic Memory in Lifelong Language Learning. NeurIPS 2019. arXiv:1905.12926.