google-research / disentanglement_lib

disentanglement_lib is an open-source library for research on learning disentangled representations.
Apache License 2.0

Lookup index for the trained models #22

Open colehurwitz opened 4 years ago

colehurwitz commented 4 years ago

Does a lookup index exist for the 10,800 pretrained disentanglement_lib models? I want to quickly reproduce specific experiments (e.g. btcvae, beta=4, dsprites), but I am not sure how to find the corresponding pretrained model.

Thanks!

ryx19th commented 4 years ago

I have an observation: taking the last group of experiments on 3dshapes as an example, there are 1800 models in total (indices 10800 to 12599). As mentioned in the paper, they evaluated 6 methods with 6 regularization strengths (i.e., hyperparameter values) each, and they also tried 50 different random seeds. That gives 6 × 6 × 50 = 1800. I guess this is the overall layout.

As for the specific configurations, you can run a few models first and use the printed configs as anchors, then calculate the index of the one you want. (Or you can modify the code to quit right after printing the config and then traverse all experiments in minutes to get the full list.)
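
The arithmetic above can be sketched in a few lines. This is only an illustration of the guessed layout: the method and strength labels are hypothetical placeholders, and the nesting order (seeds varying fastest, then strengths, then methods) is an assumption that should be checked against a few printed configs before relying on it.

```python
from itertools import product

# Hypothetical placeholder labels; the real method names, strength values,
# and their order inside the sweep are assumptions, not taken from the library.
METHODS = [f"method_{i}" for i in range(6)]
STRENGTHS = [f"strength_{i}" for i in range(6)]
SEEDS = list(range(50))
BLOCK_START = 10800  # first index of the 3dshapes block, per the comment above


def build_index(start=BLOCK_START):
    """Map each model index to a (method, strength, seed) tuple.

    Assumes seeds vary fastest, then strengths, then methods.
    """
    combos = product(METHODS, STRENGTHS, SEEDS)
    return {start + offset: combo for offset, combo in enumerate(combos)}


index = build_index()
print(len(index))              # 1800
print(min(index), max(index))  # 10800 12599
```

If the anchors you print from real runs do not match this ordering, only the order of the arguments to product needs to change.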

bonheml commented 3 years ago

This is a late answer, but I think it may be useful for other people looking for this kind of feature. You can get the list of model configurations using the get_config function in the sweep.py file of each study. From this list, you can create a dataframe to search the configs more easily.

For example, for unsupervised_study_v1:

from disentanglement_lib.config.unsupervised_study_v1.sweep import get_config
import pandas as pd

configs = get_config()
print(configs[12])  # print the config of model 12

df = pd.DataFrame(configs)  # one row per model id, one column per config key

# Get the beta_tc_vae models using dsprites and a beta param of 4.
model = df['model.name'] == "beta_tc_vae"
dataset = df['dataset.name'] == "dsprites_full"
beta = df["beta_tc_vae.beta"] == 4.0

# Show all matching model configs, dropping the columns of unused parameters (e.g., annealed_vae.gamma)
print(df[model & dataset & beta].dropna(axis=1))

# List the model ids that have this configuration
print(df.index[model & dataset & beta])
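
If pandas is not available, the same lookup can be sketched in plain Python. This assumes, as the snippet above implies, that get_config() returns a list of dicts keyed by dotted parameter names. The helper find_model_ids and the toy_configs list are hypothetical stand-ins for illustration, not part of disentanglement_lib.

```python
def find_model_ids(configs, criteria):
    """Return the indices of configs that match every key/value in criteria."""
    return [
        model_id
        for model_id, cfg in enumerate(configs)
        if all(cfg.get(key) == value for key, value in criteria.items())
    ]


# Toy stand-in for the output of get_config(); the real list is much longer.
toy_configs = [
    {"model.name": "beta_vae", "dataset.name": "dsprites_full", "beta_vae.beta": 4.0},
    {"model.name": "beta_tc_vae", "dataset.name": "dsprites_full", "beta_tc_vae.beta": 4.0},
    {"model.name": "beta_tc_vae", "dataset.name": "cars3d", "beta_tc_vae.beta": 4.0},
    {"model.name": "beta_tc_vae", "dataset.name": "dsprites_full", "beta_tc_vae.beta": 6.0},
]

matches = find_model_ids(
    toy_configs,
    {
        "model.name": "beta_tc_vae",
        "dataset.name": "dsprites_full",
        "beta_tc_vae.beta": 4.0,
    },
)
print(matches)  # [1]
```

With the real library, passing the result of get_config() in place of toy_configs should return one id per random seed for the chosen configuration.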