For the questions above:
Sorry for the delay and thank you for looking into this.
The weighting in the original implementation is equivalent to
loss = F.cross_entropy(predictions, targets, reduction='none')
weights = class_weights[targets]
loss *= weights
return loss.sum() / loss.shape[0]
Whereas yours is equivalent to
loss = F.cross_entropy(predictions, targets, reduction='none')
weights = class_weights[targets]
loss *= weights
return loss.sum() / weights.sum()
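A minimal sketch to illustrate the difference (the tensors below are made-up example data, not from the repository): the two reductions differ exactly by the factor weights.sum() / batch_size, which changes with the label composition of each batch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch_size = 4, 8
class_weights = torch.tensor([0.5, 1.0, 2.0, 4.0])   # precomputed inverse-frequency weights (illustrative)
predictions = torch.randn(batch_size, num_classes)   # raw logits
targets = torch.randint(num_classes, (batch_size,))

loss = F.cross_entropy(predictions, targets, reduction='none')
weights = class_weights[targets]
loss = loss * weights

mean_reduced = loss.sum() / loss.shape[0]    # divide by the batch size (tf.reduce_mean behaviour)
weight_reduced = loss.sum() / weights.sum()  # divide by the summed per-sample weights

# The ratio between the two is weights.sum() / batch_size, which depends on the targets in the batch.
print(mean_reduced.item(), weight_reduced.item(), (weights.sum() / loss.shape[0]).item())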
Hi,
I'm really sorry for the late reply; I was tied up with some personal affairs. After reading your latest reply, I understood your concern, so I went back to the original repository to check whether I had made a mistake. I found the original loss function implementation there:
def get_loss(self, logits, labels, pre_cal_weights):
    # calculate the weighted cross entropy according to the inverse frequency
    class_weights = tf.convert_to_tensor(pre_cal_weights, dtype=tf.float32)
    one_hot_labels = tf.one_hot(labels, depth=self.config.num_classes)
    weights = tf.reduce_sum(class_weights * one_hot_labels, axis=1)
    unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot_labels)
    weighted_losses = unweighted_losses * weights
    output_loss = tf.reduce_mean(weighted_losses)
    return output_loss
I think you are correct. It first calculates the unweighted losses (same as F.cross_entropy(predictions, targets, reduction='none')), then weights them using the pre-computed class weights (same as weights = class_weights[targets]; loss *= weights). However, in the final step the original implementation applies tf.reduce_mean(weighted_losses), while my implementation uses a different denominator, loss.sum() / weights.sum().
Thanks a lot for catching this. I will fix and push the code when I am less busy. However, I think this is not a big issue, because loss.shape[0] and weights.sum() are just scaling terms: the former is the batch size and the latter comes from the pre-computed class weights.
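For reference, a minimal sketch of what the mean-reduced version (matching tf.reduce_mean) could look like in PyTorch; the function name and arguments are only illustrative, not the repository's actual code:

import torch
import torch.nn.functional as F

def weighted_ce_loss(logits, labels, class_weights):
    # per-point cross entropy, no reduction yet
    unweighted = F.cross_entropy(logits, labels, reduction='none')
    # weight each sample by its precomputed inverse-frequency class weight
    weighted = unweighted * class_weights[labels]
    # mean over the batch, matching tf.reduce_mean in the original implementation
    return weighted.mean()

As a side note, passing the class weights directly via F.cross_entropy(logits, labels, weight=class_weights) with the default 'mean' reduction normalizes by the sum of the per-sample weights, so it matches the current loss.sum() / weights.sum() behaviour rather than tf.reduce_mean.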
Closing this issue due to long inactivity.
Firstly, thank you for your work; an up-to-date PyTorch implementation of RandLA is really nice to have. This is not a bug report but rather a series of questions I had when I started implementing RandLA before I found this repository.