google-research / disentanglement_lib

disentanglement_lib is an open-source library for research on learning disentangled representations.
Apache License 2.0

Weak Dataset Sampling - Possible Discrepancy Between Implementation and Arxiv Paper? #31

Open nmichlo opened 4 years ago


I have been examining the implementation that generates the weakly supervised data set for the paper Weakly-Supervised Disentanglement Without Compromises, and I think there is a discrepancy between the function simple_dynamics in disentanglement_lib/methods/weak/train_weak_lib.py and what is described in Section 5 of version 3 of the arXiv paper.

Examined Function

https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L41-L57

Excerpt & Corresponding Code

The following is an excerpt from the Experimental Setup subsection of Section 5 of Weakly-Supervised Disentanglement Without Compromises, split into parts that I assume correspond to the code above:

  1. To create data sets with weak supervision from the existing disentanglement data sets, we first sample from the discrete z according to the ground-truth generative model (1)–(2).

    https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L41-L42

  2. Then, we sample k factors of variation that should not be shared by the two images...

    https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L44-L49

  3. ... and re-sample those coordinates to obtain z̃. This ensures that each image pair differs in at most k factors of variation.

    https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L50-L54

  4. For k we consider the range from 1 to d − 1. This last setting corresponds to the case where all but one factor of variation are re-sampled. We study both the case where k is constant across all pairs in the data set and where k is sampled uniformly in the range [d − 1] for every training pair (k = Rnd in the following). Unless specified otherwise, we aggregate the results for all values of k.
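Read literally, the four steps of the excerpt describe roughly the following procedure. This is a minimal NumPy sketch under my own assumptions, not the library's code: sample_weak_pair and factor_sizes are hypothetical names, z is sampled uniformly over each discrete factor, and k = -1 stands in for k = Rnd.

```python
import numpy as np


def sample_weak_pair(factor_sizes, k, random_state):
    """Hypothetical sketch of steps 1-4: returns (z, z_tilde, k_used).

    factor_sizes: number of discrete values of each factor (length d).
    k: number of factors to re-sample, in [1, d - 1]; -1 means k = Rnd.
    """
    factor_sizes = np.asarray(factor_sizes)
    d = len(factor_sizes)
    # Step 1: sample z from the (assumed uniform) ground-truth prior.
    z = np.array([random_state.randint(n) for n in factor_sizes])
    # Step 4: k is either fixed or drawn uniformly from {1, ..., d - 1}.
    k_used = random_state.randint(1, d) if k == -1 else k
    # Step 2: choose exactly k_used distinct factors that should not be shared.
    index_list = random_state.choice(d, k_used, replace=False)
    # Step 3: re-sample those coordinates uniformly to obtain z_tilde.
    z_tilde = z.copy()
    for index in index_list:
        z_tilde[index] = random_state.randint(factor_sizes[index])
    return z, z_tilde, k_used


if __name__ == "__main__":
    rs = np.random.RandomState(0)
    z, z_tilde, k_used = sample_weak_pair([3, 6, 40, 32, 32], k=2, random_state=rs)
    print(z, z_tilde, k_used)
```

Because step 3 re-samples uniformly, z and z_tilde differ in at most k_used coordinates, which matches the "at most k" wording; the number of selected factors, however, is always exactly k_used.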

Problem With 2?

Based on the excerpt above, there seems to be a problem with these lines: https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L48-L49

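In particular, the expression random_state.choice([1, k_observed]) picks the number of factors to re-sample from {1, k_observed} with equal probability. So instead of k staying fixed, only a single factor is re-sampled for roughly half of the pairs.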

I may be misreading the excerpt, but this behaviour seems odd to me.
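To make the effect concrete, here is a small self-contained simulation. count_resampled is a hypothetical helper that only mirrors the two expressions under discussion (the current random_state.choice([1, k_observed]) as the sample size, and the proposed fixed k_observed); d = 5 and k_observed = 3 are arbitrary example values.

```python
import numpy as np


def count_resampled(num_factors, k_observed, random_state, use_current=True):
    """Hypothetical helper: how many factor indices are selected for re-sampling."""
    if use_current:
        # Current behaviour: the sample size itself is drawn from {1, k_observed}.
        size = random_state.choice([1, k_observed])
    else:
        # Proposed behaviour: always select exactly k_observed factors.
        size = k_observed
    index_list = random_state.choice(num_factors, size, replace=False)
    return len(index_list)


rs = np.random.RandomState(0)
current = [count_resampled(5, 3, rs) for _ in range(10_000)]
proposed = [count_resampled(5, 3, rs, use_current=False) for _ in range(10_000)]
# Roughly half of the current draws select a single factor; none of the proposed ones do.
print("current, fraction with 1 factor:", np.mean(np.array(current) == 1))
print("proposed, fraction with 1 factor:", np.mean(np.array(proposed) == 1))
```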

Fix?

Based on this, shouldn't lines 48 and 49 be the following?

index_list = random_state.choice(z.shape[1], k_observed, replace=False)

Problem With 4?

Based on the following excerpt, it seems as though the selected factors in each sampled pair should always differ.

...We study both the case where k is constant across all pairs in the data set...

https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/methods/weak/train_weak_lib.py#L52-L53

However, based on lines 52-53, this is not the case: a re-sampled factor can end up with the same value it had before, so it is not guaranteed to differ.

This probability increases as the number of possible values of that factor decreases (for uniform re-sampling it is 1/n for a factor with n values).
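The effect can be quantified with a short sketch, assuming lines 52-53 re-sample each selected factor uniformly: summing (1 - 1/n_i) over the selected factors gives the expected number that actually change. Both functions are hypothetical, and the factor sizes used below are made-up example values, with a Monte Carlo estimate as a sanity check.

```python
import numpy as np


def expected_changed(selected_sizes):
    """Expected number of selected factors whose value actually changes
    under uniform re-sampling (each keeps its value with probability 1/n)."""
    sizes = np.asarray(selected_sizes, dtype=float)
    return float(np.sum(1.0 - 1.0 / sizes))


def simulate_changed(selected_sizes, num_trials, random_state):
    """Monte Carlo estimate of the same quantity."""
    sizes = np.asarray(selected_sizes)
    # randint broadcasts the per-factor upper bound over the trial axis.
    old = random_state.randint(0, sizes, size=(num_trials, len(sizes)))
    new = random_state.randint(0, sizes, size=(num_trials, len(sizes)))
    return float(np.mean(np.sum(old != new, axis=1)))


if __name__ == "__main__":
    rs = np.random.RandomState(0)
    # Two selected factors with 2 and 40 values: (1 - 1/2) + (1 - 1/40).
    print(expected_changed([2, 40]))
    print(simulate_changed([2, 40], 100_000, rs))
```

A binary factor, for example, keeps its value half of the time, so a pair nominally re-sampled with k = 1 on that factor is identical in 50% of cases.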

Fix?

Sample the new value with the factor's original value removed from its range of possible values.

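Untested possible code for a single factor index (np.random.choice does not accept a set, so the choices are built as a list):

# Exclude the factor's current value so the re-sampled value is guaranteed to differ.
choices = [v for v in range(ground_truth_data.factors_num_values[index]) if v != z[0, index]]
z[0, index] = random_state.choice(choices)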
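To extend this to every selected factor and a whole batch, one alternative to building an explicit list of choices is a random non-zero offset modulo n: (old + r) mod n, with r drawn uniformly from {1, ..., n - 1}, is uniform over the n - 1 values other than old. The sketch below is hypothetical (resample_distinct is not a library function) and assumes z has shape (batch_size, num_factors) and that every selected factor has at least two values, since a single-valued factor can never change.

```python
import numpy as np


def resample_distinct(z, factor_sizes, index_list, random_state):
    """Hypothetical sketch: re-sample the factors in index_list so that each
    one is guaranteed to differ from its current value.

    z: integer array of shape (batch_size, num_factors); a modified copy is returned.
    """
    z = np.array(z, copy=True)
    for index in index_list:
        n = factor_sizes[index]
        if n < 2:
            raise ValueError(f"factor {index} has only {n} value(s) and cannot change")
        # An offset in {1, ..., n - 1} never maps a value onto itself,
        # and (old + offset) % n is uniform over the other n - 1 values.
        offset = random_state.randint(1, n, size=z.shape[0])
        z[:, index] = (z[:, index] + offset) % n
    return z


if __name__ == "__main__":
    rs = np.random.RandomState(0)
    sizes = [3, 6, 40, 32, 32]
    z = np.stack([rs.randint(0, n, size=8) for n in sizes], axis=1)
    index_list = rs.choice(len(sizes), 2, replace=False)
    z_tilde = resample_distinct(z, sizes, index_list, rs)
    print(index_list, (z != z_tilde).sum(axis=0))
```

Restricting the index choice to factors with at least two values (or raising, as here) is a design decision the original function would also need to make.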