CanPeng123 / FSCIL_ALICE


Clarification regarding classifier and class-prototype embedding used in different stages #8

Closed: prachigarg23 closed this issue 1 year ago

prachigarg23 commented 1 year ago

Hi @CanPeng123, I wanted to confirm which backbone and classifier are used to obtain features in the different training and testing stages of the base and incremental sessions.

Base session:

  • Training: the entire backbone, the projection layers, and the angular FC are trained.

  • Testing: in class NCMValidation (line 93), embedding = self.model(data) invokes model_factory > model > self.encoder, which includes the projection layers but not the angular FC. This output is used as the embedding for NCM evaluation.
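A minimal sketch of the base-session split described above, assuming toy NumPy stand-ins for the real PyTorch modules: backbone, encoder, angular_fc and all sizes are hypothetical names chosen for illustration, not the repo's code. The angular FC is modeled as a scaled cosine classifier; any angular margin would be applied only inside the training loss and is omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, for illustration only.
IN_DIM, FEAT_DIM, PROJ_DIM, NUM_BASE_CLASSES = 32, 16, 8, 5

W_backbone = rng.standard_normal((IN_DIM, FEAT_DIM))
W_proj = rng.standard_normal((FEAT_DIM, PROJ_DIM))
W_fc = rng.standard_normal((PROJ_DIM, NUM_BASE_CLASSES))


def l2_normalize(x, axis=-1, eps=1e-12):
    # Scale vectors along `axis` to unit length.
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def backbone(x):
    # Stand-in for the ResNet feature extractor (one ReLU layer here).
    return np.maximum(x @ W_backbone, 0.0)


def encoder(x):
    # Backbone followed by the projection head; no classifier.
    return backbone(x) @ W_proj


def angular_fc(z, scale=16.0):
    # Scaled cosine similarity between unit embeddings and unit class weights.
    return scale * (l2_normalize(z) @ l2_normalize(W_fc, axis=0))


x = rng.standard_normal((4, IN_DIM))
train_logits = angular_fc(encoder(x))  # base-session training path
ncm_embedding = encoder(x)             # base-session NCM validation path
print(train_logits.shape, ncm_embedding.shape)  # (4, 5) (4, 8)
```

The point of the sketch is only which output feeds which stage: the loss sees class logits, while NCM validation sees the projection output.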

Incremental session:

  • Training/Testing: in main_inc_ncm.py (line 295), embedding = model.encode(data) is used; this returns the ResNet backbone output after average pooling and flattening.
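A minimal NumPy sketch of the "average pooling and flattening" step, assuming the backbone yields a feature map of shape (batch, channels, height, width). The function name encode mirrors the call above, but this body is an illustrative assumption, not the repo's implementation; in PyTorch the same operation is typically written with nn.AdaptiveAvgPool2d(1) followed by torch.flatten(x, 1).

```python
import numpy as np


def encode(feature_map):
    """Global average pooling over the spatial axes, then flatten.

    feature_map: array of shape (batch, channels, height, width), standing in
    for the output of the last ResNet stage.
    Returns an array of shape (batch, channels).
    """
    pooled = feature_map.mean(axis=(2, 3), keepdims=True)  # (B, C, 1, 1)
    return pooled.reshape(pooled.shape[0], -1)              # (B, C)


fmap = np.arange(2 * 3 * 2 * 2, dtype=float).reshape(2, 3, 2, 2)
emb = encode(fmap)
print(emb.shape)  # (2, 3)
print(emb[0])     # [1.5 5.5 9.5]
```

Each output value is the mean of one channel's spatial map, so the embedding dimension equals the backbone's channel count.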

I'm confused because this model was initialized with an FC layer (lines 185-193 in main_inc_ncm.py), but that FC does not seem to be used anywhere in the incremental-session training or evaluation process.

In the paper, Section 4.3 mentions removing the projection head and the angular-penalty classification head after base-session training, but I cannot find exactly which encoder output is used to compute the class-wise prototypes in the incremental sessions.

CanPeng123 commented 1 year ago

Hi, the main_base.py file is used for base-session training. During the base session, the feature extractor, the projection layer, and the augmented classifier are all used and trained.

The main_inc_ncm.py file is used for the subsequent incremental sessions. During incremental sessions, only the feature extractor trained in the base session is used, and classification is done with class-wise average prototypes. So yes, the FC layer in main_inc_ncm.py is not used. I created it for other experiments, but it is not part of the final ALICE method. The class-wise prototypes are computed by the function calculate_avg_feature_for_each_cls() at line 270 of main_inc_ncm.py.
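A minimal sketch of prototype-based (nearest-class-mean) classification as described above, assuming embeddings have already been extracted by the frozen feature extractor. class_prototypes and ncm_predict are hypothetical helpers, not the repo's calculate_avg_feature_for_each_cls(); using cosine similarity as the distance is also an assumption made for illustration.

```python
import numpy as np


def class_prototypes(embeddings, labels):
    # Average the embeddings of each class; returns (class ids, prototype matrix).
    classes = np.unique(labels)
    protos = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    return classes, protos


def ncm_predict(embeddings, classes, protos, eps=1e-12):
    # Assign each embedding to the prototype with the highest cosine similarity.
    e = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + eps)
    p = protos / (np.linalg.norm(protos, axis=1, keepdims=True) + eps)
    return classes[np.argmax(e @ p.T, axis=1)]


# Base-session prototypes (toy 2-D embeddings).
base_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
base_lbl = np.array([0, 0, 1, 1])
cls, protos = class_prototypes(base_emb, base_lbl)

# Incremental session: a new class 2 with two shots; no FC layer is trained,
# its prototype is simply appended to the existing ones.
new_emb = np.array([[-1.0, 0.0], [-0.9, -0.1]])
new_lbl = np.array([2, 2])
new_cls, new_protos = class_prototypes(new_emb, new_lbl)
cls = np.concatenate([cls, new_cls])
protos = np.vstack([protos, new_protos])

queries = np.array([[0.8, 0.2], [0.2, 0.8], [-0.7, 0.1]])
print(ncm_predict(queries, cls, protos))  # [0 1 2]
```

Because new classes only add rows to the prototype matrix, the base-class prototypes stay unchanged across sessions, which is consistent with the FC layer being unnecessary here.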