
VardaGPT

VardaGPT is a memory-enhanced GPT-2 model powered by Hugging Face Transformers and FAISS. Inspired by J.R.R. Tolkien's Silmarillion, VardaGPT aims to provide guidance and knowledge through its memory-augmented text generation capabilities.

TL;DR: Training

The VardaGPTAssociative model combines GPT-2 with an associative memory to improve context retrieval. This repository includes a script to train this model on the WikiText-2 dataset.

Requirements

To install the required packages, you can use the following command:

pip install -r requirements.txt

Usage

To train the VardaGPTAssociative model on the WikiText-2 dataset, use the provided training script (train_varda_gpt_associative.py). You can customize the training settings by passing command-line arguments. Here's a basic example:

python train_varda_gpt_associative.py --epochs 5 --learning_rate 1e-4 --use_gpu

The example above sets the number of training epochs, the learning rate, and GPU usage. Run the script with --help to see the full list of available command-line arguments.
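
For orientation, the flags used in the example could be defined with argparse roughly as follows. This is an illustrative sketch, not the script's actual argument parser; the defaults shown are assumptions.

```python
# Illustrative sketch of an argparse CLI exposing the flags used in the example above.
# Defaults are assumptions, not the script's actual values.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train the VardaGPTAssociative model on WikiText-2.")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Optimizer learning rate.")
    parser.add_argument("--use_gpu", action="store_true", help="Train on GPU if available.")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    print(args)
```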

During training, the script will periodically print the training loss, validation loss, and elapsed time for each epoch, along with a progress bar for each training step.

After training, you can use the trained model for your specific use case, such as text generation or fine-tuning for a particular task.

Overview

```plantuml
@startuml
!define AWSPUML https://raw.githubusercontent.com/awslabs/aws-icons-for-plantuml/v14.0

actor User

skinparam component {
  BackgroundColor<> LightSkyBlue
  BackgroundColor<> Plum
  BackgroundColor<> LightGreen
  BackgroundColor<> LightSalmon
  BackgroundColor<> LightCoral
  BorderColor Black
  FontName Arial
}

package "VardaGPT" {
  [Data Preparation]<> --> [FAISS Memory]<>
  [Data Preparation]<> --> [GPT-2 Adaptation]<>
  [FAISS Memory]<> --> [GPT-2 Adaptation]<>
  [GPT-2 Adaptation]<> --> [Training]<>
  [Training]<> --> [Inference]<>
  [FAISS Memory]<> --> [Inference]<>

  User --> [Data Preparation]<> : Dataset
  User --> [Inference]<> : Prompts
}
@enduml
```


This diagram shows the main components of the VardaGPT project and their interactions. The Data Preparation component processes the dataset and feeds it to both the FAISS Memory Model and the GPT-2 Model Adaptation component. The FAISS Memory Model generates embeddings, which are used by the GPT-2 Model Adaptation component to create a modified GPT-2 model. The modified GPT-2 model is then trained and evaluated, and the final trained model is used in the Inference and Application component. The user provides the dataset and prompts for text generation.
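
As a minimal sketch of the FAISS side of this flow (assuming a flat L2 index and a GPT-2-sized embedding dimension of 768; the project's actual index type and memory wrapper may differ), storing and searching embeddings looks like this:

```python
# Minimal sketch of a FAISS-backed memory: store embeddings in a flat L2 index
# and retrieve nearest neighbours. Dimensions and sizes here are assumptions.
import faiss
import numpy as np

embedding_dim = 768  # GPT-2 hidden size (assumed)
memory = faiss.IndexFlatL2(embedding_dim)

# Store a batch of (random, stand-in) embeddings in memory.
stored_vectors = np.random.rand(1000, embedding_dim).astype("float32")
memory.add(stored_vectors)

# For each query embedding, retrieve the 5 nearest stored vectors.
queries = np.random.rand(4, embedding_dim).astype("float32")
distances, indices = memory.search(queries, 5)
print(indices.shape)  # (4, 5): five neighbour ids per query
```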

Models

The associative memory model:

```plantuml
@startuml
rectangle "Input Vectors" as input #b3e0ff
rectangle "Memory" as memory #f2d7b9
rectangle "Concatenated Input" as concatenated_input #f6e3c6
rectangle "Fully Connected Layer (fc)" as fc #e5ebf0
rectangle "GPT-2 Transformer" as transformer #c6e0b4
rectangle "GPT-2 LM Head" as lm_head #c9daf8
rectangle "Fully Connected Layer\n(fc_storable_vector)" as fc_storable_vector #c9daf8
rectangle "Fully Connected Layer\n(fc_store_decision)" as fc_store_decision #c9daf8

input -down-> memory : Perform search in memory
memory -down-> concatenated_input : Concatenate search results with input vectors
concatenated_input -down-> fc : Apply fully connected layer (fc)
fc -down-> transformer : Pass through GPT-2 transformer
transformer -down-> lm_head : Apply GPT-2 lm_head
transformer -right-> fc_storable_vector : Apply fully connected layer (fc_storable_vector)
transformer -right-> fc_store_decision : Apply fully connected layer (fc_store_decision)

note right of fc_storable_vector: Calculate storable vector\n and store decision
note right of fc_store_decision: Store the storable_vector in\n the associative memory if\n the store_decision is affirmative
note bottom of lm_head: Return logits
@enduml
```


```plantuml
@startuml
title Forward Function
!define Tensor(t,d) t + " (" + d + ")"
!define DEVICE "device"

actor "input_vectors" as input_vectors
actor "memory_input" as memory_input
note right of input_vectors
  Tensor: (batch_size, seq_len, embedding_dim)
end note
note right of memory_input
  Tensor (optional): (batch_size, seq_len, embedding_dim)
end note

input_vectors -> DEVICE
memory_input -> DEVICE
DEVICE -> "search(memory_input)" as search
search --> "indices, distances" as search_result
note right of search_result
  Tensors:
  indices: (batch_size, seq_len, num_search_results)
  distances: (batch_size, seq_len, num_search_results)
end note

search_result -> "get_all_embeddings()" as all_embeddings
note right of all_embeddings
  Tensor: (memory_size, embedding_dim)
end note
all_embeddings -> "search_results" as search_results
note right of search_results
  Tensor: (batch_size, seq_len, search_results_dim)
end note

search_results --> "concatenate(input_vectors, search_results)" as concatenated_input
note right of concatenated_input
  Tensor: (batch_size, seq_len, embedding_dim + search_results_dim)
end note
concatenated_input --> "self.fc(concatenated_input)" as fc_output
note right of fc_output
  Tensor: (batch_size, seq_len, embedding_dim)
end note

fc_output --> "self.gpt2_model.transformer(inputs_embeds=input_vectors)" as transformer_outputs
transformer_outputs --> "hidden_states" as hidden_states
note right of hidden_states
  Tensor: (batch_size, seq_len, embedding_dim)
end note
hidden_states --> "self.gpt2_model.lm_head(hidden_states)" as logits
note right of logits
  Tensor: (batch_size, seq_len, vocab_size)
end note

hidden_states --> "self.fc_storable_vector(hidden_states)" as storable_vector
note right of storable_vector
  Tensor: (batch_size, seq_len, memory_dim)
end note
hidden_states --> "self.fc_store_decision(hidden_states)" as store_decision
note right of store_decision
  Tensor: (batch_size, seq_len, 1)
end note
hidden_states --> "self.fc_delete_decision(hidden_states)" as delete_decision
note right of delete_decision
  Tensor: (batch_size, seq_len, num_search_results)
end note
hidden_states --> "self.fc_deletable_vector(hidden_states)" as deletable_vector
note right of deletable_vector
  Tensor: (batch_size, seq_len, memory_dim)
end note

storable_vector --> "self.memory.add(storable_vector_to_store)" as add_memory
deletable_vector --> "calculate L2 distances" as l2_distances
note right of l2_distances
  Tensor: (batch_size, num_search_results)
end note
l2_distances --> "threshold comparison" as threshold_comparison
note right of threshold_comparison
  Tensor (bool): (batch_size, num_search_results)
end note
threshold_comparison --> "self.memory.remove(indices_to_delete_flat)" as remove_memory

logits --> "return logits" as return_logits
@enduml
```
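
Read as code, the diagram corresponds roughly to the following PyTorch sketch. It is a simplified reconstruction based on the shapes in the notes, not the repository's implementation: the layer names (fc, fc_storable_vector, fc_store_decision) follow the diagram, the delete path is omitted, and Hugging Face's GPT2LMHeadModel stands in for the adapted GPT-2.

```python
# Simplified reconstruction of the forward pass in the diagram above (delete path omitted).
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel

class AssociativeForwardSketch(nn.Module):
    def __init__(self, embedding_dim=768, memory_dim=768, num_search_results=5):
        super().__init__()
        self.gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
        search_results_dim = num_search_results * memory_dim
        # Project [input ; retrieved memory vectors] back down to the GPT-2 embedding size.
        self.fc = nn.Linear(embedding_dim + search_results_dim, embedding_dim)
        self.fc_storable_vector = nn.Linear(embedding_dim, memory_dim)
        self.fc_store_decision = nn.Linear(embedding_dim, 1)

    def forward(self, input_vectors, search_results):
        # input_vectors:  (batch, seq_len, embedding_dim)
        # search_results: (batch, seq_len, search_results_dim), gathered from the FAISS memory
        concatenated_input = torch.cat([input_vectors, search_results], dim=-1)
        fused = self.fc(concatenated_input)                                   # (batch, seq_len, embedding_dim)
        hidden_states = self.gpt2_model.transformer(inputs_embeds=fused).last_hidden_state
        logits = self.gpt2_model.lm_head(hidden_states)                       # (batch, seq_len, vocab_size)
        storable_vector = self.fc_storable_vector(hidden_states)              # candidate memory entry
        store_decision = torch.sigmoid(self.fc_store_decision(hidden_states)) # probability of storing it
        return logits, storable_vector, store_decision
```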


Training, Evaluation, and Fine-tuning Process

```plantuml
@startuml
skinparam activity {
  BackgroundColor LightSkyBlue
  BorderColor Black
  FontName Arial
}

start
:Data Preparation;

partition "FAISS Memory Model" {
  :Create FAISS Index;
  :Encode and Decode Text Data;
  :Test FAISS Index;
}

partition "GPT-2 Model Adaptation" {
  :Load Pre-trained GPT-2 Model;
  :Modify GPT-2 Architecture;
  :Define Custom Loss Function;
}

partition "Training" {
  :Train Adapted GPT-2 Model;
  :Save Model Checkpoints;
}

partition "Evaluation" {
  :Evaluate Model on Testing Set;
  :Calculate Metrics;
}

if (Fine-tuning needed?) then (Yes)
  partition "Fine-tuning" {
    :Adjust Hyperparameters;
    :Iterate Training and Evaluation;
  }
endif

partition "Inference and Application" {
  :Inference Function;
  :API or Interface;
}

stop
@enduml
```


1. Data Preparation

2. GPT-2 Model Adaptation

3. Training

4. Evaluation

5. Fine-tuning (if necessary)

Prerequisites

Setup

  1. Clone the repository:
git clone https://github.com/yourusername/VardaGPT.git
cd VardaGPT
  2. Create and activate a virtual environment:
python -m venv venv
source venv/bin/activate
  3. Install the required libraries:
pip install -r requirements.txt

Directory Structure

Usage

Data Preparation

  1. Place your dataset in the data/ directory.
  2. Preprocess and split your dataset into training, validation, and testing sets using the provided scripts in src/ (a rough illustration of this step follows below).
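
As a rough illustration of this preparation step (assuming the Hugging Face datasets library and the WikiText-2 corpus mentioned earlier; the actual scripts in src/ may differ):

```python
# Rough illustration of dataset preparation (assumed; the project's scripts in src/ may differ).
from datasets import load_dataset
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# WikiText-2 already ships with train/validation/test splits.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
print({split: len(tokenized[split]) for split in tokenized})
```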

Training

  1. Configure the training settings and model hyperparameters in the src/config.py file.
  2. Run the training script:
python src/train.py
  3. Monitor the training progress; checkpoints are saved to the models/ directory (a minimal sketch of the underlying loop follows below).
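
The real logic lives in src/train.py; the sketch below shows the kind of loop it is likely built around. The checkpoint path, the hyperparameters, and the assumption that the model returns a loss when given labels are all illustrative.

```python
# Minimal training-loop sketch (assumed, not the project's src/train.py).
# Assumes a Hugging Face-style model that returns an output with .loss when labels are supplied.
import torch

def train(model, train_loader, epochs=5, lr=1e-4, device="cuda"):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch}: mean train loss {total_loss / len(train_loader):.4f}")
        torch.save(model.state_dict(), f"models/checkpoint_epoch_{epoch}.pt")  # hypothetical path
```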

Evaluation

  1. Evaluate the trained model on the validation and testing sets using the provided evaluation script (a hedged sketch of the metrics pass follows the command):
python src/evaluate.py
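
As a hedged sketch of what such an evaluation pass typically computes (mean loss and perplexity; the metrics actually reported by src/evaluate.py may differ):

```python
# Hedged sketch of an evaluation pass reporting mean loss and perplexity.
# Assumes a Hugging Face-style model that returns an output with .loss when labels are supplied.
import math
import torch

@torch.no_grad()
def evaluate(model, data_loader, device="cuda"):
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in data_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        total_loss += model(**batch).loss.item()
        num_batches += 1
    mean_loss = total_loss / max(num_batches, 1)
    return mean_loss, math.exp(mean_loss)  # (validation loss, perplexity)
```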

Inference

  1. Use the provided inference script to generate text with the memory-enhanced GPT-2 model (an illustrative generation snippet follows the command):
python src/inference.py --prompt "Your prompt text here"
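
The snippet below illustrates prompt-based generation with a plain Hugging Face GPT-2 model as a stand-in; src/inference.py wires in the trained, memory-enhanced model instead, so treat the model name and sampling settings here as assumptions.

```python
# Illustrative prompt-based generation; the plain "gpt2" checkpoint is a stand-in
# for the trained VardaGPT model loaded by src/inference.py.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Your prompt text here", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.9)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```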

Contributing

Feel free to contribute to this project by submitting pull requests or opening issues for bug reports and feature requests.

Code Formatting and Pre-commit

This project uses black, flake8, and mypy for Python code formatting and linting. We also use prettier to format JSON and Markdown files. The configuration for these tools is in the .pre-commit-config.yaml file.

Setup

  1. Install pre-commit if you haven't already:
pip install pre-commit
  2. Set up the git hooks:
pre-commit install

Using Pre-commit

Whenever you commit changes, the pre-commit hooks will automatically format your code and check for issues. If the hooks detect any problems, the commit will be aborted, and you'll see a list of issues that need to be fixed. Once you've resolved the issues, you can try committing again.

You can also run the pre-commit hooks manually on all files:

pre-commit run --all-files

Or run the hooks on specific files:

pre-commit run --files <file1> <file2>

By following this setup and using pre-commit hooks, you can ensure that the code in the repository remains consistently formatted and adheres to the project's coding standards.

License

This project is licensed under the MIT License.