vicmax opened this issue 3 years ago
Hi, thank you for your interest. To answer your questions:
There are three groups of data in our active learning setting: warmstart data, online data, and test data. If we want to estimate the label shift between the online data and test data, we need to randomly label some datapoints from the online data (the warmstart data will not suffice as they may be sampled from a different distribution). We refer to these randomly labeled datapoints as the "initial set". The remainder of the online data is then referred to as the "online set" (this naming is indeed confusing, sorry). For the purposes of replicating our results, you can ignore the initial set.
We do access the test data to estimate label shift. However, we do not access the labels of said test data. The explanation for this is described in Page 3 (arxiv version of our paper), between Eqs 5 & 6. We need to access unlabeled data from the test domain in order to estimate label shift. We do not need to directly estimate the label shift between the medial and the labeled distribution. We only need to sample something similar to the medial distribution from unlabeled data, which we do with proxy labels. See Page 4's left column.
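Concretely, one standard estimator of this kind (BBSE-style, with toy numbers; the names `confusion` and `test_pred_dist` are illustrative, not from this repo) looks like:

```python
import numpy as np

# Confusion matrix C[i, j] = P(predict i, true label j), estimated on
# held-out *labeled* source data.
confusion = np.array([[0.45, 0.05],
                      [0.05, 0.45]])

# Distribution of the classifier's *predictions* on unlabeled test data.
test_pred_dist = np.array([0.6, 0.4])

# Solve C w = mu_test for the importance weights w = q_test(y) / p_source(y),
# using only unlabeled test inputs (no test labels needed).
weights = np.linalg.solve(confusion, test_pred_dist)
```

The key point is that the right-hand side only needs the classifier's outputs on unlabeled test points.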
Our code supports posterior regularization (mentioned in Pages 8-9). Rather than using our importance weights during learning, we can apply them during inference time to manipulate the logits predicted by our network. We find this to be more stable when importance weights are large.
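In toy form, applying the weights at inference amounts to reweighting the softmax output and renormalizing (a sketch, not our actual code; `adjust_posterior` is a made-up helper):

```python
import numpy as np

def adjust_posterior(logits, class_weights):
    """Reweight the softmax posterior by importance weights at inference
    time, instead of using them in the training loss (made-up helper)."""
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    p = p * class_weights          # tilt the posterior toward the test prior
    return p / p.sum(axis=-1, keepdims=True)

probs = adjust_posterior(np.array([[2.0, 1.0, 0.1]]),
                         np.array([0.5, 1.0, 3.0]))
```

Equivalently, this adds the log weights to the logits before the softmax, which avoids touching the training loss entirely.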
In the specific line of code you referenced, we didn't need normalization after taking the `exp` because we only took the argmax of the variable `p`. Line 288 is unnecessary; I probably just added it out of habit.
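To illustrate why skipping normalization is safe here: dividing by a positive constant preserves the ordering, so it cannot change the argmax.

```python
import numpy as np

logits = np.array([0.3, 2.1, -1.0, 0.9])
unnormalized = np.exp(logits)                    # what the code keeps
normalized = unnormalized / unnormalized.sum()   # the full softmax

# Normalization is a positive rescaling, so the predicted class
# is identical either way.
assert np.argmax(unnormalized) == np.argmax(normalized) == 1
```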
Those main steps seem correct.
The medial distribution is mainly used around line 300 of `sampling.py`, where we perform subsampling. When `diversify = guess`, the medial distribution is uniform. When `diversify = hypsubguess`, the medial distribution is the square-root medial shown in Figure 5 of the paper.
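As a rough illustration of that subsampling (hand-rolled toy code, not what `sampling.py` actually does; `proxy_labels` stands in for the model's guessed labels):

```python
import numpy as np

rng = np.random.default_rng(0)
proxy_labels = rng.integers(0, 3, size=1000)   # guessed labels for the pool

counts = np.bincount(proxy_labels, minlength=3)
emp = counts / counts.sum()                    # empirical (guessed) label dist

medial = np.sqrt(emp)
medial /= medial.sum()                         # square-root medial distribution

# Accept each point with probability proportional to medial / empirical,
# so the kept subsample is approximately medial-distributed.
accept = medial[proxy_labels] / emp[proxy_labels]
accept /= accept.max()
keep = rng.random(len(proxy_labels)) < accept
subsample = proxy_labels[keep]
```

With a uniform medial (the `guess` case) the same rejection step flattens the guessed label distribution instead.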
I hope that helps and I'm happy to answer other questions. You can also reach me over email if you'd like to chat over the phone. Thanks again for your interest!
Thank you for your reply! I have some questions about the sampling process.
Take the function `iwal_bootstrap` in `sampling.py` as an example. When `diversify=="guess"` (between L423 and L433), it seems you are trying to scale the variable `all_probs` with the pseudo labels `ys` obtained in L408. But from L446 on, you just select samples from the online unlabeled data based on `sample_probs`, which was already obtained in L416.
So it seems that the medial distribution doesn't have **any influence on the sampling process.** The same holds for `diversify=="subguess"` with the square-root medial distribution, and for the function `iwal_bootstrap_old`.
Am I misunderstanding something here? If so, could you please provide a more detailed explanation?
`iwal_bootstrap` refers to the IWAL algorithm, a baseline which does not use the medial distribution. The choice of medial distribution (configured using the variable `diversify`) does have an influence on `general_sampling`, which implements our proposed method.
OK, it seems that I misunderstood it. I thought `iwal_bootstrap` referred to your proposed method.
Other questions about `general_sampling()`:
- Which settings in your paper do `diversify=="none"`, `"guess"`, and `"overguess"` respectively correspond to? The names are really confusing.
- Should I set `iterative_iw=True` and `train_iw=True`? Is that right?
Thank you for sharing the code of this interesting work.
As you mentioned, the experiments can be reproduced by `python3 -m alsa.main.replicate`. However, it is really not easy to follow the code. When I began with `alsa/main/replicate.py`, I had some questions about the implementation details:

**Build the dataset**
I noticed that in your implementation, when you build the dataset in `replicate.py`, we call the function `ActiveDataset.divide()`. In `divide()`, I can see that the training set is first split into the warmstart set (warmup in the paper) and the online set (unlabeled pool in the paper), i.e., `train -> warmup + online`, and both are then shifted to the desired distribution, e.g., uniform or Dirichlet. Then a subset of the online set is split off as the initial set, i.e., `online -> initial + online` (see code).

My questions: What is the role of this initial set? Which part of the paper does it correspond to? It seems that when we train the model for the first time (see code), we use `warmstart + initial` as the training set.
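For concreteness, my reading of the split is roughly the following (a toy sketch of index bookkeeping, not `ActiveDataset.divide()` itself):

```python
import numpy as np

rng = np.random.default_rng(0)
train = np.arange(1000)          # indices of the full training pool

# train -> warmup + online
warmup = train[:200]
online = train[200:]

# (both sets would then be shifted to the desired label distribution,
# e.g. uniform or Dirichlet, before the next step)

# online -> initial + online: a small randomly labeled chunk is carved
# out of the online pool.
perm = rng.permutation(online)
initial, online = perm[:50], perm[50:]
```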
**About the label shift estimation and the importance weighting in training**

After the first training process, we estimate the label shift between the labeled data (`warmstart + initial + online[labeled_pointers]`) and the test data, and record it in the dataset (see code). My questions are: (1) It seems you access the test data during training; is that an issue? (2) Based on my understanding, shouldn't we estimate the label shift between the medial distribution and the labeled distribution?

**Evaluation process**
The evaluation process (e.g., here and here in `replicate.py`) seems to multiply the label weights when we predict the probabilities in `get_preds()` of `alsa/nets/common.py`, i.e., `p = p * label_weights`. Why do we do this?

**In the Python functions `iwal_bootstrap` and `iwal_bootstrap_old` in `sampling.py`**

Based on my understanding of these two functions, apart from the first model h_0, you train several other models and organize them as a committee (you set the version space of the committee to 8 ResNet-18 models in the experiments of Section 5). You then select unlabeled data points using a disagreement measure based on this committee. In these two functions, you compute the predictions of all models on each unlabeled data point (variable `all_probs`), then calculate the disagreement measure as `all_probs -> probs_disagreement -> sample_probs` (see the code between Line 251 and Line 289 as an example), and finally decide whether to select a sample or not based on the probability threshold `sample_probs`.

My question is: to obtain `all_probs`, it seems you didn't normalize the model output after `exp()` in all situations. Is it a mistake? If the condition here doesn't hold, `all_probs` is not normalized and is not a probability, so won't you get a wrong result for the probability thresholding in querying, since `all_probs` is not a probability?

After reading the code, I summarize the main steps as follows:
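In toy form, my reading of this querying step is roughly the following (my own paraphrase, not your actual code; the disagreement measure here is simplified):

```python
import numpy as np

rng = np.random.default_rng(1)

# all_probs[m, n, c]: committee member m's score for class c on point n.
all_probs = rng.random((8, 5, 3))

# Disagreement: spread of the committee's scores on each point's best
# class (one common IWAL-style measure; simplified here).
mean_probs = all_probs.mean(axis=0)
best = mean_probs.argmax(axis=1)
per_member = all_probs[:, np.arange(5), best]          # shape (8, 5)
probs_disagreement = per_member.max(axis=0) - per_member.min(axis=0)

# Turn disagreement into a query probability and flip a coin per point.
sample_probs = np.clip(probs_disagreement, 0.1, 1.0)
queried = rng.random(5) < sample_probs
```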
**Step 1:** train a model with `importance weight = 1`.
**Step 2:** estimate the label shift between the test data and the labeled data, and update the `importance weight`.
**Step 3:** train several other models and form them into a committee; sample a certain number of unlabeled data points based on the disagreement within this committee (i.e., `IWAL`).
**Step 4:** finetune the model members of this committee on the previous data plus the recently labeled data, using the `importance weight` estimated last time.
**Step 5:** repeat Steps 2-4 to train the model with the new `importance weight`; record the evaluation results before breaking.

My question is: I didn't find any details about the role of the medial distribution proposed in this paper, except for using the medial distribution to measure the shift magnitude in `measure_composition.py`. Maybe I missed some details? I would appreciate it if you could explain a bit more about the details and the high-level insights.
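In code form, my understanding of the loop is roughly the following (all names are mine, just for orientation, not from the repo):

```python
def active_learning_loop(model, committee, data, num_rounds, estimate_shift,
                         query_by_disagreement, finetune, evaluate):
    """Steps 1-5 as a single loop (hypothetical driver, for orientation)."""
    importance_weight = 1.0                    # Step 1: start unweighted
    history = []
    for _ in range(num_rounds):
        # Step 2: re-estimate label shift between test and labeled data.
        importance_weight = estimate_shift(model, data)
        # Step 3: committee disagreement decides which points to label (IWAL).
        newly_labeled = query_by_disagreement(committee, data.unlabeled)
        # Step 4: finetune the committee on old + new labels, reweighted.
        finetune(committee, data.labeled + newly_labeled, importance_weight)
        history.append(evaluate(model, data.test))
    # Step 5: repeat until the budget runs out, recording results.
    return history
```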
Looking forward to your reply!