microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Unsupervised dataset: RuntimeError when setting debiased_loss=False and ignore_target_labels=True #29

Open agiannoul opened 1 year ago


For unsupervised datasets, with the parameters debiased_loss=False, ignore_source_labels=True, and ignore_target_labels=True, the following code in distance.py (lines 300-303) runs:

```python
if (targets2 is None) or self.ignore_target_labels:
    reindex_start = len(self.V1) if (self.loss == 'sinkhorn' and self.debiased_loss) else True
    X, Y_infer, Y_true = self._load_infer_labels(D2, classes2, reindex=True, reindex_start=reindex_start)
    self.targets2 = targets2 = Y_infer - reindex_start
```

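This raises `RuntimeError: Subtraction, the - operator, with a bool tensor is not supported`, caused by `Y_infer - reindex_start`: when the debiased Sinkhorn branch is not taken, `reindex_start` is `True` rather than a number. I think that on line 301 the fallback value should be 0 instead of True.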

That is, change it to:

```python
reindex_start = len(self.V1) if (self.loss == 'sinkhorn' and self.debiased_loss) else 0
```
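A minimal pure-Python sketch of the proposed logic, assuming (from line 303 subtracting `reindex_start` back out) that `_load_infer_labels` returns inferred target labels already shifted by `reindex_start`. `reindex_offset` and `shift_back` are hypothetical helpers, not part of otdd, and plain lists stand in for tensors. With the fix, the offset is always an `int`, so the subtraction never receives a bool.

```python
from typing import List


def reindex_offset(loss: str, debiased_loss: bool, n_source_classes: int) -> int:
    """Hypothetical helper mirroring line 301 with the proposed fix.

    With a debiased Sinkhorn loss, target labels are offset past the source
    classes; otherwise no offset is applied (0 instead of True).
    """
    return n_source_classes if (loss == "sinkhorn" and debiased_loss) else 0


def shift_back(y_infer: List[int], offset: int) -> List[int]:
    """Undo the offset, as `Y_infer - reindex_start` does on line 303."""
    return [y - offset for y in y_infer]


# Assumption: inferred target labels arrive already shifted by the offset.
offset = reindex_offset("sinkhorn", debiased_loss=False, n_source_classes=10)
y_infer = [c + offset for c in [0, 2, 1]]
print(offset, shift_back(y_infer, offset))  # 0 [0, 2, 1]
```

In the debiased case the same helper returns `n_source_classes`, so labels such as `[10, 12, 11]` are shifted back to `[0, 2, 1]`.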