Code for training and benchmarking morphology-appropriate representation learning methods.
To install and use this software, follow the steps below.
First, clone this repository.

```bash
git clone https://github.com/AllenCell/benchmarking_representations
cd benchmarking_representations
```
Create a virtual environment.

```bash
conda create --name br python=3.10
conda activate br
```
Depending on your GPU setup, you may need to set the `CUDA_VISIBLE_DEVICES` environment variable. To do this, first get the Universally Unique IDs (UUIDs) of your GPUs (e.g. by running `nvidia-smi -L`), then set `CUDA_VISIBLE_DEVICES` to some or all of them as a comma-separated list, as in the following examples.
Example 1: Using GPUs 0 and 1

```bash
export CUDA_VISIBLE_DEVICES=0,1
```

Example 2: Using one partition of a MIG-partitioned GPU

```bash
export CUDA_VISIBLE_DEVICES=MIG-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
```
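Once the environment is installed (see below), you can confirm that PyTorch sees the intended devices. This is a minimal sanity check, not part of the project's scripts:

```python
# Sanity check: list the GPUs visible to PyTorch under CUDA_VISIBLE_DEVICES.
import torch

print(torch.cuda.is_available())   # True if at least one GPU is visible
print(torch.cuda.device_count())   # number of visible devices
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))
```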
Next, install all required packages.

```bash
pip install -r requirements1.txt
pip install -r requirements2.txt
pip install -e .
```
For `pdm` users, follow these installation steps instead.
Q: When installing dependencies, PyTorch fails to install with the following error message:

```
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus
```

A: You may need to configure the `CUDA_VISIBLE_DEVICES` environment variable, as described above.
To run the models, you must set the `CYTODL_CONFIG_PATH` environment variable to point to the `br/configs` folder. Check that your current working directory is the `benchmarking_representations` folder, then run the following command (the setting lasts only for the duration of your shell session):

```bash
export CYTODL_CONFIG_PATH=$PWD/configs/
```
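As a quick check that the variable is set correctly (a minimal sketch; the `data` subfolder is one of the config folders listed later in this README):

```python
# Verify CYTODL_CONFIG_PATH is set and points at the configs folder.
import os
from pathlib import Path

cfg_root = Path(os.environ["CYTODL_CONFIG_PATH"])  # KeyError if not set
assert (cfg_root / "data").is_dir(), f"no data/ folder under {cfg_root}"
print(f"configs found at {cfg_root}")
```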
> [!NOTE]
> Be sure to download all the data into the same folder where the repo was cloned!
```
├── data
│   ├── preprocessing
│   │   ├── image_preprocessing
│   │   │   ├── config
│   │   │   │   ├── config.yaml    <- Data config for image processing workflow
```
Then follow the installation steps to run the Snakefile located in
```
├── data
│   ├── preprocessing
│   │   ├── image_preprocessing
│   │   │   ├── Snakefile    <- Image preprocessing workflow. Combines alignment, masking, registration
```
For point cloud preprocessing of punctate structures, update the data paths and run the workflow in
```
├── data
│   ├── preprocessing
│   │   ├── pc_preprocessing
│   │   │   ├── punctate_cyto.py    <- Point cloud sampling from raw images for punctate structures here
```
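For orientation, the core operation is to sample point coordinates from the raw image in proportion to voxel intensity. Below is a minimal illustrative sketch of that idea, not the project's actual implementation (function and parameter names are made up):

```python
# Illustrative intensity-weighted point cloud sampling from a 3D image.
import numpy as np

def sample_point_cloud(img: np.ndarray, n_points: int = 2048, seed: int = 0) -> np.ndarray:
    """Sample voxel coordinates with probability proportional to intensity."""
    rng = np.random.default_rng(seed)
    weights = np.clip(img.ravel().astype(np.float64), 0, None)
    weights /= weights.sum()  # assumes the image has some positive signal
    idx = rng.choice(img.size, size=n_points, p=weights)
    coords = np.column_stack(np.unravel_index(idx, img.shape))  # (n_points, 3), z/y/x order
    return coords.astype(np.float32)
```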
For SDF preprocessing of polymorphic structures, update the data paths and run the workflows in
```
├── data
│   ├── preprocessing
│   │   ├── sdf_preprocessing
│   │   │   ├── image_sdfs.py    <- Create 32**3 resolution SDF images
│   │   │   ├── pc_sdfs.py       <- Sample point clouds from 32**3 resolution SDF images
```
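Conceptually, this converts a binary segmentation into a signed distance field (SDF) and then samples points near the zero level set (the surface). Here is a minimal sketch of that idea, not the project's actual scripts (assumes a binary 3D segmentation; names are illustrative):

```python
# Illustrative signed-distance-field computation from a binary segmentation.
import numpy as np
from scipy import ndimage

def binary_to_sdf(seg: np.ndarray) -> np.ndarray:
    """Signed distance: negative inside the object, positive outside."""
    seg = seg.astype(bool)
    outside = ndimage.distance_transform_edt(~seg)  # distance to object, for background voxels
    inside = ndimage.distance_transform_edt(seg)    # distance to background, for object voxels
    return (outside - inside).astype(np.float32)

def sample_surface_points(sdf: np.ndarray, n_points: int = 2048,
                          band: float = 1.0, seed: int = 0) -> np.ndarray:
    """Sample voxels within a narrow band around the zero level set (the surface)."""
    rng = np.random.default_rng(seed)
    candidates = np.argwhere(np.abs(sdf) < band)
    idx = rng.choice(len(candidates), size=n_points, replace=True)
    return candidates[idx].astype(np.float32)
```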
In all cases, create a single-cell manifest (e.g. CSV, Parquet) for each dataset with a column of final processed paths and a split column assigning each cell to the train/test/validation split, as in the sketch below.
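For example (a minimal sketch; the file and column names here are illustrative, not required by the project; writing Parquet needs `pyarrow` or `fastparquet`):

```python
# Illustrative single-cell manifest with processed paths and a split column.
import numpy as np
import pandas as pd

# Paths to the final processed files produced by the preprocessing workflows.
paths = [f"processed/cell_{i:04d}.ply" for i in range(100)]  # placeholder paths

rng = np.random.default_rng(0)
split = rng.choice(["train", "val", "test"], size=len(paths), p=[0.8, 0.1, 0.1])

manifest = pd.DataFrame({"pointcloud_path": paths, "split": split})
manifest.to_parquet("manifest.parquet")
```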
Update the single-cell dataset path (`SINGLE_CELL_DATASET_PATH`) and the manifest column for the appropriate input modality (`SDF_COLUMN`/`SEG_COLUMN`/`POINTCLOUD_COLUMN`/`IMAGE_COLUMN`) in each datamodule yaml file. For example, for the PCNA data these yaml files are located here:

```
├── configs
│   ├── data
│   │   ├── pcna
│   │   │   ├── image.yaml                  <- Datamodule for PCNA images
│   │   │   ├── pc.yaml                     <- Datamodule for PCNA point clouds
│   │   │   ├── pc_intensity.yaml           <- Datamodule for PCNA point clouds with intensity
│   │   │   ├── pc_intensity_jitter.yaml    <- Datamodule for PCNA point clouds with intensity and jitter
```
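To find the entries to edit, one option is to load a datamodule yaml and print it (a minimal sketch; `OmegaConf` is part of the Hydra stack these configs are written for, and the path shown is the PCNA point-cloud datamodule from above):

```python
# Print a datamodule config to locate the dataset-path and column entries to edit.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/data/pcna/pc.yaml")
print(OmegaConf.to_yaml(cfg))
```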
Experiment configs, combining data, model, and logger settings, live under `configs/experiment`. For example, for the cellPACK synthetic data:

```
├── configs
│   ├── experiment
│   │   ├── cellpack
│   │   │   ├── image_equiv.yaml    <- Rotation invariant image model experiment
│   │   │   ├── pc_equiv.yaml       <- Rotation invariant point cloud model experiment
```
Here is an example of training a rotation-invariant point cloud model:

```bash
python src/br/models/train.py experiment=cellpack/pc_equiv ++mlflow.experiment_name=[EXPERIMENT_NAME] ++mlflow.run_name=[RUN_NAME]
```
Override parts of the experiment config via the command line or manually in the configs. For example, to train a classical model instead, run:

```bash
python src/br/models/train.py experiment=cellpack/pc_equiv model=pc/classical_earthmovers_sphere ++mlflow.experiment_name=[EXPERIMENT_NAME] ++mlflow.run_name=[RUN_NAME]
```
The figures from the analysis can be reproduced via the notebooks in `br/notebooks`:

```
├── br
│   ├── notebooks
│   │   ├── fig2_cellpack.ipynb             <- Reproduce Fig 2 cellPACK synthetic data results
│   │   ├── fig3_pcna.ipynb                 <- Reproduce Fig 3 PCNA data results
│   │   ├── fig4_other_punctate.ipynb       <- Reproduce Fig 4 other punctate structure data results
│   │   ├── fig5_npm1.ipynb                 <- Reproduce Fig 5 npm1 data results
│   │   ├── fig6_other_polymorphic.ipynb    <- Reproduce Fig 6 other polymorphic data results
│   │   ├── fig7_drug_data.ipynb            <- Reproduce Fig 7 drug data results
```
The project is organized as follows:

```
├── LICENSE
├── README.md            <- The top-level README for developers using this project.
├── br                   <- Source code for use in this project.
│   ├── data
│   │   ├── preprocessing          <- Preprocessing scripts to generate point clouds and SDFs
│   │   ├── get_datamodules.py     <- Get final list of datamodules per dataset
│   │
│   ├── models               <- Training and inference scripts
│   │   ├── train.py               <- Training script using cyto_dl given an experiment config
│   │   ├── predict_model.py       <- Inference functions
│   │   ├── save_embeddings.py     <- Save embeddings using inference functions
│   │   ├── load_models.py         <- Load trained models based on checkpoint paths
│   │   ├── compute_features.py    <- Compute multi-metric features for each model based on saved embeddings
│   │
│   ├── features             <- Metrics for benchmarking each model
│   │   ├── archetype.py              <- Archetype analysis functions
│   │   ├── classification.py         <- Test set classification accuracies using logistic regression classifiers
│   │   ├── outlier_compactness.py    <- Intrinsic dimensionality calculation and outlier classification
│   │   ├── reconstruction.py         <- Functions for reconstruction viz across models
│   │   ├── regression.py             <- Linear regression test set r^2
│   │   ├── rotation_invariance.py    <- Sensitivity to four 90 degree rotations in embedding space
│   │   ├── plot.py                   <- Polar plot viz across metrics
│
├── configs              <- Training configs for each experiment
│   ├── callbacks            <- e.g. early stopping, model checkpointing
│   ├── data                 <- Datamodules for each dataset
│   ├── experiment           <- Training config for an experiment combining data, models, logger
│   ├── model                <- Config for PyTorch Lightning model
│   ├── trainer              <- Trainer parameters for PyTorch Lightning
│   ├── logger               <- Choice of logger to save results
│   ├── hydra                <- Hydra params to perform experiment sweeps
│
├── notebooks            <- Jupyter notebooks.
│
├── pointcloudutils
│   ├── pointcloudutils
│   │   ├── datamodules          <- Custom datamodules
│   │   │   ├── cellpack.py          <- CellPACK data specific datamodule
│   │   ├── networks             <- Custom networks
│   │   │   ├── simple_inner.py      <- Inner product decoder for SDF reconstruction
│
├── pyproject.toml       <- Makes project pip installable (pip install -e .) so br can be imported
```