ericzhao28 / ALLS

Active learning under label shift.

Questions about the code #2

Open vicmax opened 3 years ago

vicmax commented 3 years ago

Thank you for sharing the code of this interesting work.

As you mentioned, the experiments can be reproduced with python3 -m alsa.main.replicate. However, the code is not easy to follow. Starting from alsa/main/replicate.py, I have some questions about the implementation details:

Build the dataset

I noticed that when the dataset is built in replicate.py, the function ActiveDataset.divide() is called. In divide(), the training set is first split into a warmstart set (the warmup set in the paper) and an online set (the unlabeled pool in the paper), i.e., train -> warmstart + online, and both are then shifted to the desired label distribution, such as uniform or Dirichlet. A subset of the online set is then split off as an initial set, i.e., online -> initial + online (see code).

My questions: what is the role of this initial set, and which part of the paper does it correspond to? It seems that when we train the model for the first time (see code), we use warmstart + initial as the training set.
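To make the split described above concrete, here is a simplified sketch. ActiveDataset.divide() itself is not shown in this thread, so everything below is an assumption: the training set is halved at random, each half is shifted by per-class subsampling to a target label distribution, and a random slice of the online set becomes the initial set. The names shift_to and divide and all parameters are hypothetical.

```python
import numpy as np


def shift_to(indices, labels, dist, size, rng):
    """Subsample `indices` so class counts follow `dist` (approximately)."""
    counts = np.floor(np.asarray(dist) * size).astype(int)
    chosen = []
    for cls, count in enumerate(counts):
        pool = indices[labels[indices] == cls]
        # Sample without replacement, capped at what this class actually has.
        take = min(count, len(pool))
        chosen.append(rng.choice(pool, size=take, replace=False))
    return np.concatenate(chosen)


def divide(labels, warm_size, online_size, initial_size,
           warm_dist, online_dist, seed=0):
    """Hypothetical train -> warmstart + online, then online -> initial + online."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(labels))
    half = len(labels) // 2
    # train -> warmstart + online: disjoint halves, each shifted to its own distribution
    warmstart = shift_to(perm[:half], labels, warm_dist, warm_size, rng)
    online = shift_to(perm[half:], labels, online_dist, online_size, rng)
    # online -> initial + online: a random subset that gets labeled up front
    online = rng.permutation(online)
    initial, online = online[:initial_size], online[initial_size:]
    return warmstart, initial, online
```

Because the initial set is drawn at random from the shifted online data, its labels are an unbiased sample of the online label distribution, which the warmstart set cannot provide when the two distributions differ.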

About the label shift estimation and the importance weighting in training

After the first training run, we estimate the label shift between the labeled data (warmstart + initial + online[labeled_pointers]) and the test data, and record it in the dataset (see code). My questions are: (1) It seems you access the test data during training; is that an issue? (2) Based on my understanding, shouldn't we estimate the label shift between the medial distribution and the labeled distribution?

Evaluation process

The evaluation process (e.g., in two places in replicate.py) seems to multiply the predicted probabilities by the label weights in get_preds() of alls/alsa/nets/common.py, i.e., p = p * label_weights. Why do we do this?

The Python functions iwal_bootstrap and iwal_bootstrap_old in sampling.py

Based on my understanding of these two functions, apart from the first model h_0, you train several other models and organize them into a committee (in the Experiments of Section 5, the version space is a committee of 8 ResNet-18 models). You then select unlabeled data points using a disagreement measure based on this committee. In these two functions, you compute the predictions of all models on each unlabeled data point (variable all_probs), then compute the disagreement measure as all_probs -> probs_disagreement -> sample_probs (see lines 251 to 289 for an example), and finally decide whether to select each sample based on the probability threshold sample_probs.
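The committee-to-query-probability pipeline described above can be sketched as follows. The thread does not show lines 251 to 289, so this is not the repo's code: it is a minimal version of bootstrap IWAL's query rule (Beygelzimer et al., 2009), assuming the loss is 1 - p(y|x) so that the largest loss gap between two members equals the largest per-class probability spread. The names query_probs and sample and the default p_min are hypothetical.

```python
import numpy as np


def query_probs(all_probs, p_min=0.1):
    """Bootstrap-IWAL-style query probabilities from a committee.

    all_probs: (m, n, k) array; all_probs[i, j] is model i's *normalized*
    class distribution for unlabeled point j.
    """
    spread = all_probs.max(axis=0) - all_probs.min(axis=0)  # (n, k) per-class spread
    disagreement = spread.max(axis=1)                        # worst-case class per point
    # Stays in [p_min, 1] only if every row of all_probs sums to 1.
    return p_min + (1.0 - p_min) * disagreement


def sample(all_probs, p_min=0.1, seed=0):
    """Flip a Bernoulli(q) coin per point; queried points get weight 1/q."""
    rng = np.random.default_rng(seed)
    q = query_probs(all_probs, p_min)
    queried = rng.random(q.shape) < q
    weights = np.where(queried, 1.0 / q, 0.0)
    return queried, weights
```

This also shows where the normalization question matters: if a member's outputs do not sum to 1, the spread is no longer bounded by 1, and q can leave the [p_min, 1] range that the coin flip assumes.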

My question is: to obtain all_probs, it seems you don't normalize the model outputs after exp() in all cases. Is this a mistake?

If the condition here doesn't hold, all_probs is not normalized and is therefore not a probability, so the probability thresholding used for querying will give a wrong result.

After reading the code, I summarize the main steps as follows:

Step 1: Train a model with importance weight = 1.
Step 2: Estimate the label shift between the test data and the labeled data, and update the importance weights.
Step 3: Train other models to form a committee, and sample a certain number of unlabeled points based on the committee's disagreement (i.e., IWAL).
Step 4: Fine-tune the committee members on the previous data plus the newly labeled data, using the most recently estimated importance weights.
Step 5: Repeat Steps 2-4, training the model with the new importance weights, and record the evaluation results before stopping.
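The control flow of the steps above can be sketched as a loop. This is only a skeleton under the assumption that the summary is accurate: fit, estimate_weights, select and evaluate are hypothetical placeholders for the repo's training, shift estimation, IWAL querying and evaluation code, and none of them corresponds to a specific function in ALLS.

```python
from typing import Callable, List


def active_learning_loop(labeled: list, pool: list, fit: Callable,
                         estimate_weights: Callable, select: Callable,
                         evaluate: Callable, num_classes: int,
                         committee_size: int, rounds: int) -> List[float]:
    """Skeleton of Steps 1-5; every callable is a hypothetical placeholder."""
    # Step 1: train the first model h_0 with all importance weights equal to 1
    weights = [1.0] * num_classes
    committee = [fit(None, labeled, weights, 0)]
    history = []
    for _ in range(rounds):  # Step 5: repeat Steps 2-4
        # Step 2: re-estimate the label shift and update the importance weights
        weights = estimate_weights(committee[0], labeled)
        # Step 3: grow the committee (first round only), then query by disagreement
        while len(committee) < committee_size:
            committee.append(fit(None, labeled, weights, len(committee)))
        queried = select(committee, pool)
        chosen = set(queried)
        labeled = labeled + list(queried)
        pool = [x for x in pool if x not in chosen]
        # Step 4: fine-tune each member on old + newly labeled data
        committee = [fit(model, labeled, weights, seed)
                     for seed, model in enumerate(committee)]
        history.append(evaluate(committee[0]))
    return history
```

Passing the previous model into fit lets one placeholder stand for both training from scratch (model is None) and fine-tuning.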

My question is: apart from using the medial distribution to measure the shift magnitude in measure_composition.py, I couldn't find any details about the role of the medial distribution proposed in the paper. Maybe I missed something? I would appreciate it if you could explain the details and the high-level insights a bit more.

Looking forward to your reply!

ericzhao28 commented 3 years ago

Hi, thank you for your interest. To answer your questions:

Build the dataset

There are three groups of data in our active learning setting: warmstart data, online data, and test data. If we want to estimate the label shift between the online data and test data, we need to randomly label some datapoints from the online data (the warmstart data will not suffice as they may be sampled from a different distribution). We refer to these randomly labeled datapoints as the "initial set". The remainder of the online data is then referred to as the "online set" (this naming is indeed confusing, sorry). For the purposes of replicating our results, you can ignore the initial set.

Label shift estimation

We do access the test data to estimate label shift. However, we do not access the labels of that test data. This is explained on page 3 of the arXiv version of our paper, between Eqs. 5 and 6: we need unlabeled data from the test domain in order to estimate label shift. We do not need to directly estimate the label shift between the medial and the labeled distribution. We only need to sample something similar to the medial distribution from unlabeled data, which we do with proxy labels. See the left column of page 4.
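To illustrate why unlabeled test data is enough, here is a sketch of black box shift estimation (BBSE, Lipton et al., 2018), a standard estimator that uses only the model's predictions on the target data. The thread does not show ALLS's estimator, so treat the claim that it matches this as an assumption; the function name is hypothetical.

```python
import numpy as np


def bbse_weights(val_labels, val_preds, test_preds, num_classes):
    """Estimate w(y) = p_test(y) / p_labeled(y) by solving C w = mu.

    C[i, j] = P(prediction = i, label = j) on labeled data.
    mu[i]   = P(prediction = i) on *unlabeled* test data.
    """
    val_labels = np.asarray(val_labels)
    val_preds = np.asarray(val_preds)
    test_preds = np.asarray(test_preds)
    C = np.zeros((num_classes, num_classes))
    np.add.at(C, (val_preds, val_labels), 1.0)  # joint confusion counts
    C /= len(val_labels)
    mu = np.bincount(test_preds, minlength=num_classes) / len(test_preds)
    w = np.linalg.solve(C, mu)
    return np.clip(w, 0.0, None)  # negative estimates are not valid weights
```

For example, with a perfect classifier on a balanced labeled set and test predictions [0, 1, 1, 1], the estimate is w = [0.5, 1.5]. No test label is ever read.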

Evaluation

Our code supports posterior regularization (mentioned on pages 8-9). Rather than using our importance weights during learning, we can apply them at inference time to manipulate the logits predicted by our network. We find this to be more stable when importance weights are large.
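The inference-time use of the weights can be sketched as below. Under label shift, the target posterior is proportional to p(y|x) * w(y), which is what p = p * label_weights computes up to a constant. The renormalization step is an addition for readability; get_preds() may not do it, and it does not change the argmax either way. The function name is hypothetical.

```python
import numpy as np


def reweighted_preds(probs, label_weights):
    """Correct source-trained posteriors for label shift at inference time.

    probs: (n, k) softmax outputs of a model trained on the labeled distribution.
    label_weights: (k,) estimates of p_target(y) / p_source(y).
    """
    p = probs * label_weights              # proportional to the target posterior
    return p / p.sum(axis=1, keepdims=True)  # renormalize each row
```

For instance, probs [[0.6, 0.4]] with weights [0.5, 1.5] become [[1/3, 2/3]], so the prediction flips from class 0 to class 1 without retraining.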

Sampling

In the specific line of code you referenced, we didn't need normalization after taking the exp because we only take the argmax of the variable p. Line 288 is unnecessary; I probably just added it out of habit.
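The distinction between argmax and thresholding can be checked numerically. Logits differ from log-probabilities by a per-row constant, so exp(logits) is the distribution times a positive factor: the argmax is unaffected, but the raw values are not probabilities. The values below are illustrative only.

```python
import numpy as np

probs = np.array([[0.7, 0.2, 0.1]])
# Logits = log-probabilities + a per-row constant (the log-partition term).
logits = np.log(probs) + 2.0
scores = np.exp(logits)  # unnormalized: probs scaled by e**2
normalized = scores / scores.sum(axis=1, keepdims=True)

# argmax ignores positive per-row rescaling, so it needs no normalization...
assert (scores.argmax(axis=1) == normalized.argmax(axis=1)).all()
# ...but the raw scores exceed 1, so a threshold or a disagreement
# computed directly on them would be off.
print(scores.max())  # about 5.17, i.e. 0.7 * e**2
```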

Main steps

Those main steps seem correct.

Medial distribution

The medial distribution is mainly used around line 300 of sampling.py, where we perform subsampling. When diversify = guess, the medial distribution is uniform. When diversify = hypsubguess, the medial distribution is the square root medial shown in Figure 5 of the paper.
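The subsampling idea (reshape the unlabeled pool toward a medial distribution using proxy labels) can be sketched with class-wise rejection sampling. This is an assumption about the general technique, not the code around line 300 of sampling.py. Only the uniform medial (the analogue of diversify = guess) is implemented here; for hypsubguess, the square-root medial from Figure 5 of the paper would be passed as target instead, and its exact formula is not reproduced. All names are hypothetical.

```python
import numpy as np


def acceptance_probs(proxy_labels, target):
    """Per-class keep probabilities that reshape the pool toward `target`.

    proxy_labels: predicted (not true) labels of the unlabeled pool.
    target: desired label distribution, e.g. a medial distribution.
    """
    target = np.asarray(target, dtype=float)
    k = len(target)
    pool_dist = np.bincount(proxy_labels, minlength=k) / len(proxy_labels)
    ratio = np.divide(target, pool_dist, out=np.zeros(k), where=pool_dist > 0)
    # Scale so the class the pool under-represents most is always kept.
    return ratio / ratio.max()


def subsample(proxy_labels, target, seed=0):
    """Boolean mask over the pool: keep each point with its class's probability."""
    rng = np.random.default_rng(seed)
    keep = acceptance_probs(proxy_labels, target)[proxy_labels]
    return rng.random(len(proxy_labels)) < keep


def uniform_medial(k):
    """Uniform target, the analogue of diversify = guess."""
    return np.full(k, 1.0 / k)
```

With a pool whose proxy labels are 80% class 0 and 20% class 1, a uniform target keeps every class-1 point and about a quarter of the class-0 points. Querying then happens on the kept subset.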

I hope that helps and I'm happy to answer other questions. You can also reach me over email if you'd like to chat over the phone. Thanks again for your interest!

vicmax commented 3 years ago

Thank you for your reply! I have some questions about the sampling process.

About the role of medial distribution

Take the function iwal_bootstrap in sampling.py as an example. When diversify == "guess" (lines 423 to 433), it seems you scale the variable all_probs using the pseudo-labels ys obtained on line 408. But from line 446 onward, you select samples from the online unlabeled data based only on sample_probs, which was already computed on line 416.

So it seems that the medial distribution has no influence on the sampling process. The same holds for diversify == subguess (the square-root medial distribution) and for the function iwal_bootstrap_old.

Have I misunderstood this? If so, could you please explain in more detail?

ericzhao28 commented 3 years ago

iwal_bootstrap refers to the IWAL algorithm, a baseline which does not use the medial distribution. The choice of medial distribution (configured using the variable diversify) does have an influence on general_sampling, which implements our proposed method.

vicmax commented 3 years ago

OK, it seems I misunderstood: I thought iwal_bootstrap referred to your proposed method.

Other questions in general_sampling():