uoguelph-mlrg / PROTAX-GPU

GPU-accelerated DNA barcode classification



A GPU-accelerated, JAX-based implementation of PROTAX. This repository contains all code and experiments for PROTAX-GPU.

To reproduce the BOLD 7.8M dataset experiments, PROTAX-GPU requires an NVIDIA GPU with at least 8GB VRAM and CUDA compute capability 6.0 or later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere architecture, and NVIDIA Hopper™ architecture families.
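A quick way to compare a machine against these requirements is to query `nvidia-smi`. The sketch below assumes a reasonably recent NVIDIA driver whose `nvidia-smi` supports the `compute_cap` query field. The helper names (`parse_gpu_report`, `meets_requirements`, `query_gpus`) are illustrative and not part of PROTAX-GPU.

```python
import shutil
import subprocess

# Thresholds taken from the requirement stated above.
MIN_COMPUTE_CAPABILITY = 6.0
MIN_VRAM_MIB = 8 * 1024  # note: some "8GB" cards report slightly less than this


def parse_gpu_report(csv_text):
    """Parse `nvidia-smi --format=csv,noheader,nounits` output into dicts."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, compute_cap, memory_mib = [field.strip() for field in line.split(",")]
        gpus.append({
            "name": name,
            "compute_cap": float(compute_cap),
            "memory_mib": int(memory_mib),
        })
    return gpus


def meets_requirements(gpu):
    """True if the GPU satisfies both the compute capability and VRAM thresholds."""
    return (gpu["compute_cap"] >= MIN_COMPUTE_CAPABILITY
            and gpu["memory_mib"] >= MIN_VRAM_MIB)


def query_gpus():
    """Return parsed GPU info, or an empty list if nvidia-smi cannot be queried."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,compute_cap,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    try:
        return parse_gpu_report(result.stdout)
    except ValueError:
        # Unexpected output, e.g. a driver that does not report compute_cap.
        return []


if __name__ == "__main__":
    for gpu in query_gpus():
        status = "OK" if meets_requirements(gpu) else "below requirements"
        print(f"{gpu['name']}: {status}")
```

An empty result only means the query failed or no GPU was found; it does not by itself show that the hardware is unsuitable.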




PROTAX-GPU estimates the probability of each outcome in a taxonomic tree, given a query barcode sequence and its similarity to a set of reference sequences.
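One way to picture "a probability for each outcome in a taxonomic tree" is to give each node's children conditional probabilities and multiply them along the path from the root. The sketch below is a toy illustration of that idea only. The tree, the scores, and the function names are made up, and it is not PROTAX-GPU's actual model or API.

```python
import math

# Toy taxonomy: each node maps to its children. All names are hypothetical.
TREE = {
    "root": ["GenusA", "GenusB"],
    "GenusA": ["GenusA_sp1", "GenusA_sp2"],
    "GenusB": ["GenusB_sp1"],
}

# Hypothetical per-node scores, standing in for whatever a real model would
# derive from comparing the query sequence with reference sequences.
SCORES = {
    "GenusA": 2.0, "GenusB": 0.0,
    "GenusA_sp1": 1.0, "GenusA_sp2": 1.0,
    "GenusB_sp1": 0.0,
}


def softmax(values):
    """Turn a list of scores into probabilities that sum to 1."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def outcome_probabilities(tree, scores, node="root", prob=1.0):
    """Multiply conditional branch probabilities down the tree from `node`."""
    result = {}
    children = tree.get(node, [])
    if not children:
        return result
    branch = softmax([scores[child] for child in children])
    for child, p in zip(children, branch):
        result[child] = prob * p
        result.update(outcome_probabilities(tree, scores, child, prob * p))
    return result


probs = outcome_probabilities(TREE, SCORES)
for name, p in probs.items():
    print(f"{name}: {p:.3f}")
```

In this toy setup the probabilities at each level of the tree sum to 1, so a confident genus-level call can coexist with an uncertain species-level call.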



experiments/        All experiments for the paper
lib/                C++/CUDA code for PROTAX-GPU
scripts/            Scripts for training and inference
├-- train.py        Trains a model with gradient descent 
└-- convert.py      Converts .TSV to PROTAX-GPU format
tests/              Unit tests 
├-- knn_jax/        JAX bindings for CUDA kernels
└-- protax/         Main PROTAX-GPU code


Platform        CPU            NVIDIA GPU     Apple GPU
Linux           yes            yes            n/a
Mac (x86_64)    yes            n/a            no
Mac (ARM)       yes            n/a            no
Windows         experimental   experimental   n/a


These instructions are for Linux and macOS. Windows support is experimental.

1. Set up CUDA (required for GPU support)

IMPORTANT: PROTAX-GPU requires a full local installation of CUDA, including development headers and tools, due to its use of custom CUDA kernels.

Ensure that your system environment variables are correctly set up to point to your CUDA installation.

NOTE: While JAX offers an easier CUDA installation via pip wheels for some platforms, this method does not provide the full CUDA toolkit required by PROTAX-GPU. You must perform a local CUDA installation as described above.
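Before building, it can help to confirm that a local toolkit is actually visible from your shell. The sketch below checks a few common conventions: a `CUDA_HOME` or `CUDA_PATH` variable, `nvcc` on `PATH`, and the `cuda_runtime.h` header. These are assumptions about a typical setup; which variables the PROTAX-GPU build itself consults is not stated here, and `find_cuda_toolkit` is a hypothetical helper.

```python
import os
import shutil


def find_cuda_toolkit(env=None):
    """Return a dict describing where the CUDA toolkit appears to live."""
    env = os.environ if env is None else env
    # CUDA_HOME and CUDA_PATH are common conventions, not a PROTAX-GPU requirement.
    cuda_home = env.get("CUDA_HOME") or env.get("CUDA_PATH")
    # nvcc is the CUDA compiler; it ships with the full toolkit.
    nvcc = shutil.which("nvcc", path=env.get("PATH"))
    header = None
    if cuda_home:
        candidate = os.path.join(cuda_home, "include", "cuda_runtime.h")
        if os.path.isfile(candidate):
            header = candidate
    return {"cuda_home": cuda_home, "nvcc": nvcc, "cuda_runtime_h": header}


if __name__ == "__main__":
    for key, value in find_cuda_toolkit().items():
        print(f"{key}: {value if value else 'not found'}")
```

If `nvcc` is reported as not found, the toolkit is most likely either not installed or not on `PATH`.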

2. Create & activate new environment (recommended)

Using Conda:

conda create -n [name] python=3.10
conda activate [name]

3. Install dependencies

Install CMake:

conda install -c conda-forge cmake

4. Install JAX and jaxlib

System Type    Command
Linux CPU      pip install "jax[cpu]"
Linux GPU      pip install "jax[cuda122]"
macOS CPU      pip uninstall jax jaxlib && conda install -c conda-forge jax=0.4.26 jaxlib=0.4.23
macOS GPU      Not yet supported

NOTE: The GPU installation command assumes you have already installed CUDA 12.2 as per step 1.
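After installing, it can be worth confirming which backend JAX actually selected. The sketch below uses only `jax.__version__`, `jax.default_backend()`, and `jax.devices()`; the wrapper `describe_jax_backend` is a hypothetical helper, not part of PROTAX-GPU.

```python
def describe_jax_backend():
    """Return a one-line summary of the JAX backend, or a hint if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment"
    devices = ", ".join(str(device) for device in jax.devices())
    return (f"jax {jax.__version__}, backend={jax.default_backend()}, "
            f"devices=[{devices}]")


print(describe_jax_backend())
```

On a working Linux GPU installation the backend is expected to be a GPU backend rather than `cpu`. If it reports `cpu`, JAX did not pick up the CUDA installation.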

5. Install PROTAX-GPU

Clone this repository:

git clone https://github.com/uoguelph-mlrg/PROTAX-GPU.git

Install requirements:

pip install -r requirements.txt

Finally, install PROTAX-GPU:

pip install .

This will install a package called protax in your environment.

NOTE: JAX also supports CUDA 11.2, but that support will be dropped in the future. Any JAX release that is version 0.4.14 or later and still supports CUDA 11.2 should work, provided your local CUDA installation matches the CUDA version JAX was installed for.

TODO: Add Windows support instructions and check macOS GPU support

If you are on macOS and run into installation issues, run the following commands:

chmod +x ./scripts/fix_librhash.sh
./scripts/fix_librhash.sh


Instructions for running PROTAX-GPU for inference and training.


Once you have a trained model, you can use the classify_file function to classify query sequences.

Run the sequence classification script:

python scripts/process_seqs.py [PATH_TO_QUERY_SEQUENCES] [PATH_TO_MODEL] [PATH_TO_TAXONOMY]


For example:

python scripts/process_seqs.py data/refs.aln models/params/model.npz models/ref_db/taxonomy37k.npz


Results are saved to pyprotax_results.csv.
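To take a first look at the output, the file can be read with Python's standard `csv` module. The column layout of pyprotax_results.csv is not documented here, so this sketch makes no assumptions about it and simply returns the first few raw rows. `preview_rows` and `preview_file` are hypothetical helper names.

```python
import csv
import itertools
import os


def preview_rows(lines, limit=5):
    """Return the first `limit` parsed rows from an iterable of CSV lines."""
    return list(itertools.islice(csv.reader(lines), limit))


def preview_file(path="pyprotax_results.csv", limit=5):
    """Open a results file and return its first few rows."""
    with open(path, newline="") as handle:
        return preview_rows(handle, limit)


# Only print a preview if the classification script has already been run here.
if os.path.exists("pyprotax_results.csv"):
    for row in preview_file():
        print(row)
```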


Run the training script from the command line, specifying the paths to your training data and target data with the --train_dir and --targ_dir arguments, respectively.

python scripts/train_model.py --train_dir [PATH_TO_TRAINING_DATA] --targ_dir [PATH_TO_TARGET_DATA]
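If training is launched from another Python program, for example as part of a larger pipeline, the same command can be assembled and run with `subprocess`. This is a minimal sketch: `build_train_command` and `run_training` are hypothetical wrappers, and only the script path and the two flags come from the command above.

```python
import subprocess
import sys
from pathlib import Path


def build_train_command(train_dir, targ_dir, script="scripts/train_model.py"):
    """Assemble the training command shown above as an argument list."""
    return [sys.executable, script,
            "--train_dir", str(train_dir),
            "--targ_dir", str(targ_dir)]


def run_training(train_dir, targ_dir):
    """Check that both paths exist, then launch the training script."""
    for label, path in (("train_dir", train_dir), ("targ_dir", targ_dir)):
        if not Path(path).exists():
            raise FileNotFoundError(f"{label} does not exist: {path}")
    # check=True raises CalledProcessError if the script exits with an error.
    return subprocess.run(build_train_command(train_dir, targ_dir), check=True)
```

Checking the paths first gives a clear error before any time is spent on start-up.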

Command Line Arguments

Training Configuration

The script uses the following training parameters:

The script uses models/params/model.npz as the baseline and saves the trained model to models/params/m2.npz.
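An `.npz` file is a zip archive holding one `.npy` file per array, so the baseline and the trained model can be compared at a glance using only the standard library. The sketch below lists array names without loading any data. `list_npz_arrays` and `compare_models` are hypothetical helpers, and the array names inside PROTAX-GPU's model files are not assumed.

```python
import os
import zipfile


def list_npz_arrays(path):
    """Return the sorted array names stored in an .npz archive."""
    suffix = ".npy"
    with zipfile.ZipFile(path) as archive:
        return sorted(name[:-len(suffix)] for name in archive.namelist()
                      if name.endswith(suffix))


def compare_models(baseline_path, trained_path):
    """Report array names present in only one of two .npz files."""
    baseline = set(list_npz_arrays(baseline_path))
    trained = set(list_npz_arrays(trained_path))
    return {"only_in_baseline": sorted(baseline - trained),
            "only_in_trained": sorted(trained - baseline)}


if __name__ == "__main__":
    baseline, trained = "models/params/model.npz", "models/params/m2.npz"
    if os.path.exists(baseline) and os.path.exists(trained):
        print(list_npz_arrays(trained))
        print(compare_models(baseline, trained))
```

To inspect the array values themselves, `numpy.load` opens the same files and returns the arrays by name.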


To reproduce the BOLD dataset experiments, PROTAX-GPU requires an NVIDIA GPU with at least 16GB VRAM and CUDA compute capability 6.0 or later. This corresponds to GPUs in the NVIDIA Pascal, NVIDIA Volta™, NVIDIA Turing™, NVIDIA Ampere architecture, and NVIDIA Hopper™ architecture families.


The BOLD 7.8M dataset is available here. The dataset is not included in this repository due to its size.

The smaller FinPROTAX dataset is included in the models directory, sourced from here.


This project is licensed under the CC BY-NC-SA 4.0 License - see the LICENSE file for details.