Surreal-GAN is a semi-supervised representation learning method designed to identify disease-related heterogeneity within a patient group. Surreal-GAN parses complex disease-related imaging patterns into low-dimensional representation indices (R-indices), with each dimension indicating the severity of one relatively homogeneous imaging pattern.
The core idea of Surreal-GAN is to model disease as a continuous process and to learn infinitely many transformation directions from CN (healthy control) data to PT (patient) data, with each direction capturing a specific combination of patterns and severity. This is realized by learning one transformation function, f, which takes both normal data and a continuous latent variable as inputs and outputs synthesized PT data whose distribution is indistinguishable from that of real PT data. As shown in the schematic diagram, several regularizations are introduced to further guide the transformation function. An inverse function, g, is trained jointly with f and is used to derive R-indices after training.
The fundamental framework of the basic Surreal-GAN model (Yang et al. 2022) inherently encourages independence among the derived R-indices, limiting its applicability in various scenarios. Therefore, starting from Surreal-GAN 0.1.1, we have implemented an updated Surreal-GAN model (Yang et al. 2023). This enhanced Surreal-GAN incorporates a correlation structure among the R-indices within a reduced representation latent space. This modification allows the model to capture interactions among multiple underlying neuropathological processes.
We strongly encourage users to upgrade to version 0.1.1.
Copyright (c) 2016 University of Pennsylvania. All rights reserved. See https://www.cbica.upenn.edu/sbia/software/license.html
We highly recommend that users install Anaconda3 on their machines. After installing Anaconda3, Surreal-GAN can be used following this procedure:
We recommend that users use the Conda virtual environment:
$ conda create --name surrealgan python=3.8
Activate the virtual environment:
$ conda activate surrealgan
Install SurrealGAN from PyPI:
$ pip install SurrealGAN
The main function of SurrealGAN takes two pandas dataframes as data inputs: data and covariate (optional). Columns named 'participant_id' and 'diagnosis' must exist in both dataframes. Conventions for the group label/diagnosis: -1 represents healthy control (CN) and 1 represents patient (PT). Categorical variables, such as sex, should be encoded as numbers, for example 0 for Female and 1 for Male.
Example for data:
participant_id diagnosis ROI1 ROI2 ...
subject-1 -1 325.4 603.4
subject-2 1 260.5 580.3
subject-3 -1 326.5 623.4
subject-4 1 301.7 590.5
subject-5 1 293.1 595.1
subject-6 1 287.8 608.9
Example for covariate:
participant_id diagnosis age sex ...
subject-1 -1 57.3 0
subject-2 1 43.5 1
subject-3 -1 53.8 1
subject-4 1 56.0 0
subject-5 1 60.0 1
subject-6 1 62.5 0
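Before training, it can help to check that both dataframes follow the conventions above. The following is a minimal sketch and is not part of the SurrealGAN package: the helper name `check_inputs` and its error messages are hypothetical, and the requirement that both dataframes list participants in the same order is an assumption made for this example.

```python
import pandas as pd

REQUIRED_COLUMNS = ("participant_id", "diagnosis")


def check_inputs(data, covariate=None):
    """Hypothetical helper: verify the input conventions described above."""
    frames = {"data": data}
    if covariate is not None:
        frames["covariate"] = covariate
    for name, df in frames.items():
        missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
        if missing:
            raise ValueError(f"{name} is missing required columns: {missing}")
        # Diagnosis must be -1 (CN) or 1 (PT).
        bad = set(df["diagnosis"].unique()) - {-1, 1}
        if bad:
            raise ValueError(
                f"{name} has unexpected diagnosis labels: {sorted(bad, key=str)}"
            )
    if covariate is not None:
        # Assumption: both dataframes describe the same participants in the same order.
        if list(data["participant_id"]) != list(covariate["participant_id"]):
            raise ValueError("participant_id differs between data and covariate")
        # Covariates such as sex should already be encoded as numbers.
        extra = covariate.drop(columns=list(REQUIRED_COLUMNS))
        non_numeric = [
            c for c in extra.columns if not pd.api.types.is_numeric_dtype(extra[c])
        ]
        if non_numeric:
            raise ValueError(f"covariate columns must be numeric: {non_numeric}")
    return True


data = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3"],
    "diagnosis": [-1, 1, -1],
    "ROI1": [325.4, 260.5, 326.5],
    "ROI2": [603.4, 580.3, 623.4],
})
covariate = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3"],
    "diagnosis": [-1, 1, -1],
    "age": [57.3, 43.5, 53.8],
    "sex": [0, 1, 1],
})
print(check_inputs(data, covariate))
```

A check like this catches label encodings such as 0/1 or string-valued covariates before a long training run starts.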
We provide a toy dataset in the SurrealGAN/dataset folder.
import pandas as pd
from SurrealGAN.Surreal_GAN_representation_learning import repetitive_representation_learning
train_data = pd.read_csv('train_roi.csv')
covariate = pd.read_csv('train_cov.csv')
output_dir = "PATH_OUTPUT_DIR"
npattern = 3
final_saving_epoch = 25000
max_epoch = 26000
## two important hyperparameters
lam = 0.2
gamma = 6
To ensure optimal performance and flexibility in Surreal-GAN representation learning, users can adjust the following hyper-parameters according to their specific needs:
batch_size
lam: the hyper-parameter tied to cluster_loss in the training objective function. Try lam values between 0.05 and 1.6, and use the results yielding the highest Rindices-Correlation as indicated in the output file.
gamma: the hyper-parameter tied to change_loss in the training objective function. Try gamma values between 0.1 and 8, and select the results with the highest Rindices-Correlation as returned in the output file.
saving_freq: set saving_freq based on the sample size and final_saving_epoch (recommended: 1/40 to 1/20 of final_saving_epoch).
final_saving_epoch: set final_saving_epoch based on the specific dataset. While a larger value of final_saving_epoch helps ensure that the optimal epoch is reached, it comes at the expense of longer training times. As outlined in the section Main function for Model Training, the optimal epoch is determined at the conclusion of the training process. If users train all repetitions in parallel, real-time monitoring of agreements among models allows for early stopping, even before reaching the set final_saving_epoch. However, when training repetitions sequentially, users should choose final_saving_epoch carefully to ensure that the best model agreement (measured by Rindices-Correlation) is attained (recommended: 1500000*(300/patient_sample_size)).
Rindices-Correlation is used as the metric for measuring agreement between results and selecting the optimal model. Specifically, it equals the mean of the following two measurements:
Dimension-correlation: With M-dimensional R-indices derived by two different models denoted $r^1$ and $r^2$, Dimension-correlation is defined as the average of the M Pearson’s correlations over all dimensions: $$\frac{1}{M}\sum_{i=1}^{M} \rho(r_i^1, r_i^2)$$.
Difference-correlation: With M-dimensional R-indices derived by two different models denoted $r^1$ and $r^2$, Difference-correlation is defined as the average of the M(M-1)/2 Pearson’s correlations over all pairs of dimensions: $$\frac{2}{M(M-1)}\sum_{i=1}^{M} \sum_{j=i+1}^{M} \rho(r_i^1 - r_j^1, r_i^2 - r_j^2)$$.
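The two measurements can be sketched in plain Python as follows. This is an illustration of the formulas above, not the package's own implementation. It assumes each model's R-indices are stored as a list of M lists (one list of subject values per dimension), that the dimensions of the two models are already matched, and that M is at least 2. All function names are hypothetical.

```python
from itertools import combinations
from math import sqrt


def pearson(x, y):
    """Pearson's correlation between two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    return cov / sqrt(var_x * var_y)


def dimension_correlation(r1, r2):
    """Average Pearson's correlation over the M matched dimensions."""
    m = len(r1)
    return sum(pearson(r1[i], r2[i]) for i in range(m)) / m


def difference_correlation(r1, r2):
    """Average Pearson's correlation of the differences r_i - r_j over all i < j."""
    pairs = list(combinations(range(len(r1)), 2))
    total = 0.0
    for i, j in pairs:
        d1 = [a - b for a, b in zip(r1[i], r1[j])]
        d2 = [a - b for a, b in zip(r2[i], r2[j])]
        total += pearson(d1, d2)
    return total / len(pairs)


def rindices_correlation(r1, r2):
    """Mean of Dimension-correlation and Difference-correlation."""
    return 0.5 * (dimension_correlation(r1, r2) + difference_correlation(r1, r2))


# Two dimensions (M = 2), four subjects; the second model is a rescaled copy.
model_a = [[0.1, 0.4, 0.8, 0.3], [0.9, 0.2, 0.5, 0.7]]
model_b = [[2 * v for v in dim] for dim in model_a]
print(rindices_correlation(model_a, model_b))
```

Because Pearson's correlation is invariant to positive rescaling, two models whose R-indices differ only by a positive scale factor agree perfectly under this metric.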
repetition_number = 30  # number of repetitions (at least 20 repetitions are needed for the most reliable and reproducible results)
data_fraction = 1  # fraction of data used in each repetition
repetitive_representation_learning(train_data, npattern, repetition_number, data_fraction, final_saving_epoch, output_dir,
    lr=0.0008, batchsize=120, verbose=False, lipschitz_k=0.5, covariate=None, start_repetition=0, lam=lam, gamma=gamma)
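Since lam and gamma are meant to be tested over a range (0.05 to 1.6 and 0.1 to 8, respectively), a small sweep can be organized as sketched below. This is only an illustration: the specific grid values, the output directory naming, and the `train_fn` callable are assumptions made for this example. In practice `train_fn` would wrap a call to repetitive_representation_learning with the remaining arguments fixed. The helper `saving_freq_range` simply evaluates the recommended 1/40 to 1/20 of final_saving_epoch.

```python
import os
from itertools import product

# Illustrative candidate values inside the recommended ranges
# (lam: 0.05 to 1.6, gamma: 0.1 to 8). The specific values are an assumption.
LAM_GRID = [0.05, 0.1, 0.2, 0.4, 0.8, 1.6]
GAMMA_GRID = [0.1, 0.5, 1, 2, 4, 8]


def saving_freq_range(final_saving_epoch):
    """Recommended saving_freq range: 1/40 to 1/20 of final_saving_epoch."""
    return final_saving_epoch // 40, final_saving_epoch // 20


def sweep(train_fn, base_output_dir, lam_grid=LAM_GRID, gamma_grid=GAMMA_GRID):
    """Call train_fn once per (lam, gamma) pair, each with its own output directory."""
    output_dirs = []
    for lam, gamma in product(lam_grid, gamma_grid):
        out_dir = os.path.join(base_output_dir, f"lam_{lam}_gamma_{gamma}")
        train_fn(lam=lam, gamma=gamma, output_dir=out_dir)
        output_dirs.append(out_dir)
    return output_dirs


def dry_run(lam, gamma, output_dir):
    # Placeholder: a real train_fn would wrap repetitive_representation_learning
    # with the remaining arguments fixed.
    print(f"lam={lam}, gamma={gamma} -> {output_dir}")


dirs = sweep(dry_run, "PATH_OUTPUT_DIR", lam_grid=[0.1, 0.2], gamma_grid=[2, 6])
print(saving_freq_range(25000))
```

After the sweep, the Rindices-Correlation reported in each run's output file can be compared to pick the best (lam, gamma) pair.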
The repetitive_representation_learning function is the cornerstone of representation learning using Surreal-GAN. It performs the repetitive training process with a user-defined number of repetitions.
For each repetition, the model is saved every saving_freq epochs before final_saving_epoch, subject to set criteria.
We recommend setting repetition_number to be greater than 20. Given the potentially prolonged duration of the repetitive training process on a standard desktop computer, the function provides an option for early stopping and later resumption. Users can set stop_repetition as an early stopping point and start_repetition as the starting repetition index.
By default, the function trains all repetition_number repetitions sequentially. This may result in an extended training time.
Restricting each run to a subset of the repetition_number repetitions (through start_repetition and stop_repetition) allows users to run multiple repetitions in parallel, which is particularly effective on HPC clusters.
Agreements among models are evaluated every saving_freq epochs and are saved in a real-time updated CSV file named 'model_agreements.csv'.
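For parallel runs, the repetitions have to be divided among jobs. The sketch below shows one way to compute contiguous (start_repetition, stop_repetition) ranges. It is an illustration only: the helper `repetition_chunks` is hypothetical, and it assumes stop_repetition is exclusive (a job trains repetitions start through stop - 1). Whether the package treats the stop index as inclusive or exclusive should be checked before relying on this.

```python
def repetition_chunks(repetition_number, n_jobs):
    """Split repetition indices 0..repetition_number-1 into contiguous (start, stop) ranges.

    Assumption: stop is exclusive, i.e. a job trains repetitions start, ..., stop - 1.
    """
    if n_jobs < 1 or n_jobs > repetition_number:
        raise ValueError("n_jobs must be between 1 and repetition_number")
    base, extra = divmod(repetition_number, n_jobs)
    chunks = []
    start = 0
    for job in range(n_jobs):
        # The first `extra` jobs take one additional repetition each.
        stop = start + base + (1 if job < extra else 0)
        chunks.append((start, stop))
        start = stop
    return chunks


for start_repetition, stop_repetition in repetition_chunks(30, 4):
    print(f"job: start_repetition={start_repetition}, stop_repetition={stop_repetition}")
```

Each job would then call repetitive_representation_learning with its own start_repetition and stop_repetition while sharing the same output directory and repetition_number.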
Upon completion of all repetitions, the function automatically saves a CSV file and returns the same dataframe. The information saved in the CSV file includes gamma and lam.
model_dir = 'PATH_TO_SAVED_MODEL'  # path to the final selected model (the one returned by the function "repetitive_representation_learning")
r_indices = apply_saved_model(model_dir, application_data, epoch, application_covariate=None)
apply_saved_model is a function used for deriving R-indices for new patient data using a previously saved model.
The function returns R-indices of PT data following the order of PT in the provided dataframe.
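Because the returned R-indices follow the order of PT rows in the provided dataframe, they can be paired back with participant IDs as sketched below. This is a hedged example rather than package functionality: it assumes the returned R-indices are array-like with one row per PT and one column per pattern, and the helper `label_r_indices`, the column names `r1`, `r2`, ... and the numeric values are all made up for illustration.

```python
import pandas as pd


def label_r_indices(application_data, r_indices):
    """Pair each row of r_indices with the participant_id of the matching PT row."""
    # Keep PT rows only (diagnosis == 1), preserving their original order.
    pt_ids = application_data.loc[application_data["diagnosis"] == 1, "participant_id"]
    table = pd.DataFrame(r_indices)
    if len(table) != len(pt_ids):
        raise ValueError("number of R-index rows does not match number of PT rows")
    table.columns = [f"r{k + 1}" for k in range(table.shape[1])]
    table.insert(0, "participant_id", pt_ids.to_list())
    return table


application_data = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3", "subject-4"],
    "diagnosis": [-1, 1, -1, 1],
    "ROI1": [325.4, 260.5, 326.5, 301.7],
})
# Made-up values standing in for the output of apply_saved_model.
fake_r_indices = [[0.12, 0.80], [0.55, 0.31]]
print(label_r_indices(application_data, fake_r_indices))
```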
If you use this package for research, please cite the following paper:
@inproceedings{yang2022surrealgan,
title={Surreal-{GAN}: Semi-Supervised Representation Learning via {GAN} for uncovering heterogeneous disease-related imaging patterns},
author={Zhijian Yang and Junhao Wen and Christos Davatzikos},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=nf3A0WZsXS5}
}