lwneal / counterfactual-open-set

Counterfactual Image Generation

Baseline Classifier uses a strange loss? #5

Open KevLuo opened 3 years ago

KevLuo commented 3 years ago

When the baseline classifier is trained, it appears not to use cross-entropy loss. On line 127 of training.py, we see:

# Classify real examples into the correct K classes with hinge loss
classifier_logits = netC(images)
errC = F.softplus(classifier_logits * -labels).mean()

According to the code, the labels here have shape b x k, where b = batch_size and k = number of closed classes, with a 1 indicating the correct class and -1 indicating an incorrect class. The * operator in PyTorch is elementwise multiplication, so if we denote x = classifier_logits * -labels, then element ij of x is the logit for the jth class of example i multiplied by the negative of the label for class j of example i. The final error is the mean of all the elements of the resulting matrix after softplus is applied elementwise to x.
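
For reference, here is a minimal sketch (with made-up shapes and random logits rather than the actual netC/images from training.py) showing that this expression amounts to an independent one-vs-rest logistic loss per class, numerically identical to binary cross-entropy with logits on 0/1 targets:

import torch
import torch.nn.functional as F

batch_size, num_classes = 4, 10
logits = torch.randn(batch_size, num_classes)            # stand-in for netC(images)
targets = torch.randint(0, num_classes, (batch_size,))   # integer class ids

# Build the +1 / -1 label matrix described above
labels = -torch.ones(batch_size, num_classes)
labels[torch.arange(batch_size), targets] = 1.0

# The loss from training.py: softplus(-y * z) == log(1 + exp(-y * z)),
# i.e. an independent binary logistic loss for every (example, class) pair
errC = F.softplus(logits * -labels).mean()

# Equivalent formulation via binary_cross_entropy_with_logits on {0, 1} targets
bce = F.binary_cross_entropy_with_logits(logits, (labels + 1) / 2)
print(torch.allclose(errC, bce))  # True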

But doesn't this formulation produce strange behavior for the loss? Because the mean is taken over the whole resulting matrix, the logit for each class of a given example is treated independently of that example's other logits. In other words, the logit for the correct class may be high while the logit for an incorrect class is also high. Under normal cross-entropy the competing logit makes the loss large, but with this implementation the two logits just contribute separate per-class terms, a low one for the correct class and a high one for the incorrect class, with no term that compares them. If such logits are used for open set detection, this could lead to unnecessarily worse performance. In addition, this loss function introduces heavy class imbalance: for a given example, far more classes are labeled -1 than +1, and each of them individually contributes an equal share of the loss signal, which may also artificially damage the baseline classifier.
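
To make the independence point concrete, here is a small hypothetical example (logits chosen by hand, not taken from the repo) where the correct class and one incorrect class both get a large logit:

import torch
import torch.nn.functional as F

# Hypothetical single example with 3 classes; class 0 is the correct class,
# but class 2 also receives a large logit.
logits = torch.tensor([[5.0, -5.0, 5.0]])
target = torch.tensor([0])
labels = torch.tensor([[1.0, -1.0, -1.0]])  # +1 correct, -1 incorrect

# Standard cross-entropy couples the logits: the competing class-2 logit
# keeps the loss around log(2) even though the correct logit is large.
ce = F.cross_entropy(logits, target)

# The per-class softplus loss treats each logit in isolation: a near-zero
# term for class 0 and a large term for class 2, with no term that
# depends on the difference between them.
per_class = F.softplus(logits * -labels)

print(ce.item())                      # ~0.69
print(per_class.squeeze().tolist())   # ~[0.007, 0.007, 5.007]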

So I have two questions: 1) Am I correct in concluding that the baseline does not use cross-entropy loss? 2) If so, why does the baseline use this alternative loss? It seems like this could lead to optimization difficulties or make the baseline less robust than if normal cross-entropy were used.