zhijian-yang / SurrealGAN

MIT License

Surreal-GAN

Surreal-GAN is a semi-supervised representation learning method designed to identify disease-related heterogeneity within a patient group. Surreal-GAN parses complex disease-related imaging patterns into low-dimensional representation indices (R-indices), with each dimension indicating the severity of one relatively homogeneous imaging pattern.

The key idea of Surreal-GAN is to model disease as a continuous process and to learn infinitely many transformation directions from healthy control (CN) data to patient (PT) data, with each direction capturing a specific combination of patterns and severity. This is realized by learning one transformation function, f, which takes both CN data and a continuous latent variable as inputs and outputs synthesized PT data whose distribution is indistinguishable from that of real PT data. As shown in the schematic diagram, several regularizations are introduced to further guide the transformation function. An inverse function, g, trained jointly with f, is used to derive R-indices after training.

The fundamental framework of the basic Surreal-GAN model (Yang et al. 2022) inherently encourages independence among the derived R-indices, limiting its applicability in various scenarios. Therefore, starting from Surreal-GAN 0.1.1, we have implemented an updated Surreal-GAN model (Yang et al. 2023). This enhanced Surreal-GAN incorporates a correlation structure among the R-indices within a reduced representation latent space. This modification allows the model to capture interactions among multiple underlying neuropathological processes.

We strongly encourage the users to upgrade to version 0.1.1.

[Figure: schematic diagram of the Surreal-GAN model]

License

Copyright (c) 2016 University of Pennsylvania. All rights reserved. See https://www.cbica.upenn.edu/sbia/software/license.html

Installation

We highly recommend that users install Anaconda3 on their machines. After installing Anaconda3, Surreal-GAN can be used following this procedure:

We recommend that users use the Conda virtual environment:

$ conda create --name surrealgan python=3.8

Activate the virtual environment

$ conda activate surrealgan

Install SurrealGAN from PyPI:

$ pip install SurrealGAN

Input structure

The main function of SurrealGAN takes two pandas dataframes as data inputs: data and covariate (optional). Columns named 'participant_id' and 'diagnosis' must exist in both dataframes. Conventions for the group label (diagnosis): -1 represents healthy control (CN) and 1 represents patient (PT). Categorical variables, such as sex, should be encoded as numbers, for example 0 for female and 1 for male.

Example for data:

participant_id    diagnosis    ROI1     ROI2 ...
subject-1         -1           325.4    603.4
subject-2          1           260.5    580.3
subject-3         -1           326.5    623.4
subject-4          1           301.7    590.5
subject-5          1           293.1    595.1
subject-6          1           287.8    608.9

Example for covariate:

participant_id    diagnosis    age     sex ...
subject-1         -1           57.3    0
subject-2          1           43.5    1
subject-3         -1           53.8    1
subject-4          1           56.0    0
subject-5          1           60.0    1
subject-6          1           62.5    0
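The conventions above can be checked before training. The following is a minimal sketch and not part of the SurrealGAN package: the helper name check_inputs is hypothetical, the requirement that both dataframes describe the same participants is an assumption, and the toy values simply mirror the first rows of the example tables.

```python
import pandas as pd

REQUIRED = ("participant_id", "diagnosis")


def check_inputs(data, covariate=None):
    """Hypothetical helper (not part of SurrealGAN): list problems found in the inputs."""
    problems = []
    frames = {"data": data}
    if covariate is not None:
        frames["covariate"] = covariate

    for name, df in frames.items():
        for col in REQUIRED:
            if col not in df.columns:
                problems.append(f"{name}: missing column '{col}'")
        # Group labels follow the README convention: -1 for CN, 1 for PT.
        if "diagnosis" in df.columns and not df["diagnosis"].isin([-1, 1]).all():
            problems.append(f"{name}: diagnosis must be -1 (CN) or 1 (PT)")

    if covariate is not None and not problems:
        # Categorical covariates such as sex should already be numeric.
        for col in covariate.columns:
            if col != "participant_id" and not pd.api.types.is_numeric_dtype(covariate[col]):
                problems.append(f"covariate: column '{col}' is not numeric")
        # Assumption: both dataframes describe the same participants.
        if set(data["participant_id"]) != set(covariate["participant_id"]):
            problems.append("data and covariate contain different participants")
    return problems


data = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3"],
    "diagnosis": [-1, 1, -1],
    "ROI1": [325.4, 260.5, 326.5],
    "ROI2": [603.4, 580.3, 623.4],
})
covariate = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3"],
    "diagnosis": [-1, 1, -1],
    "age": [57.3, 43.5, 53.8],
    "sex": ["F", "M", "M"],
})
# Encode sex as numbers: 0 for female, 1 for male.
covariate["sex"] = covariate["sex"].map({"F": 0, "M": 1})

print(check_inputs(data, covariate))
```

An empty list means that no problems were found; each entry otherwise names the dataframe and the convention it violates.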

Example

We offer a toy dataset in the folder of SurrealGAN/dataset.

import pandas as pd
from SurrealGAN.Surreal_GAN_representation_learning import repetitive_representation_learning

train_data = pd.read_csv('train_roi.csv')
covariate = pd.read_csv('train_cov.csv')

output_dir = "PATH_OUTPUT_DIR"
npattern = 3
final_saving_epoch = 25000
max_epoch = 26000

## two important hyperparameters
lam = 0.2
gamma = 6

Important Hyper-parameters

To ensure optimal performance and flexibility in Surreal-GAN representation learning, users can adjust the following hyper-parameters according to their specific needs:

batch_size

lam

gamma

saving_freq

final_saving_epoch

Rindices-Correlation

Rindices-Correlation is the metric used to measure agreement between results and to select the optimal model. Specifically, it equals the mean of the following two measurements:

Main function for Model Training

# number of repetitions (at least 20 repetitions are needed
# to give the most reliable and reproducible results)
repetition_number = 30
data_fraction = 1  # fraction of data used in each repetition
repetitive_representation_learning(train_data, npattern, repetition_number, data_fraction, final_saving_epoch, output_dir,
        lr=0.0008, batchsize=120, verbose=False, lipschitz_k=0.5, covariate=None, start_repetition=0, lam=lam, gamma=gamma)

The repetitive_representation_learning function is the cornerstone of representation learning using Surreal-GAN. It performs the repetitive training process with a user-defined number of repetitions.

Process Description

Model Saving

Optimal Saving Epoch and Repetition

Parallel vs. Sequential Training

Given the potentially prolonged duration of the repetitive training process on a standard desktop computer, the function provides an option for early stopping and later resumption. Users can set stop_repetition as an early stopping point and start_repetition to be the starting repetition index.
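One way to plan such a staged run is to split the repetitions into consecutive chunks and train one chunk per session. The sketch below only computes the chunk boundaries. The helper repetition_ranges is hypothetical, and treating stop_repetition as exclusive (the next session starts where the previous one stopped) is an assumption that should be checked against the installed version.

```python
def repetition_ranges(total_repetitions, chunk_size):
    """Hypothetical helper: split range(total_repetitions) into (start, stop) pairs.

    Assumption: stop is exclusive, so each pair covers repetitions
    start, start + 1, ..., stop - 1.
    """
    if total_repetitions <= 0 or chunk_size <= 0:
        raise ValueError("total_repetitions and chunk_size must be positive")
    return [
        (start, min(start + chunk_size, total_repetitions))
        for start in range(0, total_repetitions, chunk_size)
    ]


for start, stop in repetition_ranges(30, 10):
    # In a real session, these values would be passed to
    # repetitive_representation_learning as start_repetition=start
    # and stop_repetition=stop, one pair per session.
    print(f"session: start_repetition={start}, stop_repetition={stop}")
```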

Monitoring Training Process

Output File

Upon completion of all repetitions, the function automatically saves a CSV file and returns the same dataframe. The CSV file contains the following information:

Model Application to out-of-sample Participants

model_dir = 'PATH_TO_SAVED_MODEL'  # path to the final selected model (the one returned by the function "repetitive_representation_learning")
r_indices = apply_saved_model(model_dir, application_data, epoch, application_covariate=None)

apply_saved_model is a function used for deriving R-indices for new patient data using a previously saved model.

Input Data

Output

The function returns R-indices of PT data following the order of PT in the provided dataframe.
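Because the returned R-indices follow the order of PT rows, they can be paired back with participant IDs. The following is a minimal sketch, assuming the result is array-like with one row per PT participant and one column per pattern. The helper label_r_indices and the column names r1, r2, ... are hypothetical, and the numbers are placeholders rather than model output.

```python
import pandas as pd


def label_r_indices(application_data, r_indices):
    """Hypothetical helper: pair each PT participant_id with its row of R-indices."""
    # Keep PT rows only (diagnosis == 1), in their original order.
    pt_ids = application_data.loc[application_data["diagnosis"] == 1, "participant_id"]
    rows = [list(row) for row in r_indices]
    if len(rows) != len(pt_ids):
        raise ValueError("expected one row of R-indices per PT participant")
    n_patterns = len(rows[0]) if rows else 0
    labelled = pd.DataFrame(rows, columns=[f"r{i + 1}" for i in range(n_patterns)])
    labelled.insert(0, "participant_id", pt_ids.to_list())
    return labelled


application_data = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3", "subject-4"],
    "diagnosis": [-1, 1, -1, 1],
})
# Placeholder values standing in for the output of apply_saved_model.
placeholder_r_indices = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

labelled = label_r_indices(application_data, placeholder_r_indices)
print(labelled)
```

The length check guards against silently misaligning rows when the dataframe passed here differs from the one given to apply_saved_model.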

Citation

If you use this package for research, please cite the following paper:

@inproceedings{yang2022surrealgan,
title={Surreal-{GAN}:Semi-Supervised Representation Learning via {GAN} for uncovering heterogeneous disease-related imaging patterns},
author={Zhijian Yang and Junhao Wen and Christos Davatzikos},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=nf3A0WZsXS5}
}