MartaYang / SEGA

Codes for the WACV 2022 paper: "SEGA: Semantic Guided Attention on Visual Prototype for Few-Shot Learning"
MIT License

how to split base class in stage2 for 01_miniimagenet_stage2_SEGA_5W1S? #7

Open fikry102 opened 2 years ago

fikry102 commented 2 years ago

When I debug 01_miniimagenet_stage2_SEGA_5W1S (in `traincode.py`, inside `def train_stage2(opt)`), I get the following result:

```
Knovel_ids.size()   torch.Size([8, 5])
Kbase_ids.size()    torch.Size([8, 59])
logit_query.size()  torch.Size([8, 60, 64])
```

It seems the 64 base classes are divided into 59 Kbase classes and 5 Knovel classes, and then 64-way classification is performed?

Could you please give some more details about these results? Thanks!

MartaYang commented 2 years ago

Hi, I think the explanation you are looking for is at the end of "Training Procedure" in Section 3.3 ("Framework") of our paper:

More specifically, for each episode, we randomly sample N classes from the base classes Yb to act as “novel” classes, then sample K samples from each “novel” class to form a fake N-Way K-Shot support set. As shown in Figure 3, we can calculate N visual prototypes and enhance them using semantic guided attentions. Thus we get N classification weights which are used to replace the corresponding base classification weights (other weights are also enhanced by their own semantic attentions) in Cosine Classifier, and then perform classification and cross-entropy loss calculation.

In short, yes, we always do 64-way classification: we sample 5 classes to act as "novel" classes and generate their classification weights from prototypes, while the weights of the other 59 "base" classes come from the learned parameters `SEGAhead.weight_base`, just as in the first training stage. By the way, this training strategy follows Dynamic-FSL.
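The episode construction described above can be sketched as follows. This is a minimal NumPy illustration, not the repository's actual code; all names, sizes, and the plain-prototype step (in place of the paper's semantic-guided attention enhancement) are assumptions for clarity:

```python
import numpy as np

rng = np.random.default_rng(0)

n_base, feat_dim = 64, 640      # 64 miniImageNet base classes (feat_dim assumed)
n_way, k_shot = 5, 1            # fake 5-way 1-shot episode

# Learned base classification weights (stand-in for SEGA's weight_base).
weight_base = rng.normal(size=(n_base, feat_dim))

# Sample 5 of the 64 base classes to act as "novel" for this episode;
# the remaining 59 stay as "base" classes (matching Knovel_ids / Kbase_ids).
novel_ids = rng.choice(n_base, size=n_way, replace=False)
base_ids = np.setdiff1d(np.arange(n_base), novel_ids)

# Prototypes: mean of the K support embeddings of each fake-novel class.
support = rng.normal(size=(n_way, k_shot, feat_dim))
prototypes = support.mean(axis=1)

# Replace the sampled classes' rows with their prototypes (in SEGA these
# would first be enhanced by semantic-guided attention); the other 59 rows
# keep the learned weights.
weights = weight_base.copy()
weights[novel_ids] = prototypes

def cosine_logits(x, w, scale=10.0):
    """Cosine classifier: scaled dot product of L2-normalized rows."""
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    w = w / np.linalg.norm(w, axis=-1, keepdims=True)
    return scale * x @ w.T

# 60 query embeddings -> logits over all 64 classes, i.e. always 64-way.
queries = rng.normal(size=(60, feat_dim))
logits = cosine_logits(queries, weights)
print(logits.shape)  # (60, 64)
```

This also explains the shapes in the question: `logit_query` is `[8, 60, 64]` because each of the 8 episodes in the batch scores its 60 queries against all 64 classes, even though only 5 of the 64 weight rows came from episode prototypes.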