marcoaversa / diffinfinite

DiffInfinite Official Code
MIT License
22 stars, 3 forks

Why train only one class at a time? #9

Open: joihn opened this issue 2 months ago

joihn commented 2 months ago

Hello, and thanks for this repo. I have a question about this line:

mask=torch.where(mask==self.subclasses[self.tmp_index-1], self.tmp_index, 0)

https://github.com/marcoaversa/diffinfinite/blob/master/dataset.py#L353

If I understand correctly, the guidance mechanism used here trains only one class at a time. In other words, a mask can contain only zeros (the unknown class) and one other class (the one specifically selected for this sample). Any other classes present in the mask are set to 0.
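To make the behaviour concrete, here is a minimal standalone sketch of what the quoted line does. It assumes `self.subclasses` is a list of raw label values and `self.tmp_index` is a 1-based position into it (as the `- 1` in the indexing suggests). The wrapper name `keep_single_class` and the toy label values are hypothetical and not part of the repo.

```python
import torch

def keep_single_class(mask: torch.Tensor, subclasses: list, tmp_index: int) -> torch.Tensor:
    """Hypothetical wrapper mirroring the quoted line.

    Pixels equal to the selected raw label subclasses[tmp_index - 1] become
    tmp_index; every other pixel, including other labelled classes, becomes 0.
    """
    selected = subclasses[tmp_index - 1]
    return torch.where(mask == selected, tmp_index, 0)

# Toy 2x4 mask: raw label 3 on the left half, raw label 7 on the right half.
mask = torch.tensor([[3, 3, 7, 7],
                     [3, 3, 7, 7]])
subclasses = [3, 7]  # hypothetical raw label values

print(keep_single_class(mask, subclasses, tmp_index=1))  # left half -> 1, right half -> 0
print(keep_single_class(mask, subclasses, tmp_index=2))  # left half -> 0, right half -> 2
```

Whichever class is selected, the other labelled half of the mask is discarded for that sample.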

What is the reason behind this? Why can't a mask contain multiple classes during training, for example class idx_1 on the right half and class idx_2 on the left half?

Wouldn't making use of all the labeled pixels available at each iteration significantly speed up training?
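The alternative proposed above, keeping every labelled class in the same mask, could be sketched as follows. This is an assumption about how it might be done, not something the repo implements. Each raw label in `subclasses` is mapped to its 1-based position, and any label not in the list falls back to 0 (unknown). The name `remap_all_classes` is hypothetical.

```python
import torch

def remap_all_classes(mask: torch.Tensor, subclasses: list) -> torch.Tensor:
    """Hypothetical multi-class variant: keep all listed classes at once."""
    out = torch.zeros_like(mask)
    for idx, raw_label in enumerate(subclasses, start=1):
        # Assign class index idx wherever this raw label appears.
        out[mask == raw_label] = idx
    return out  # raw labels not in subclasses stay 0 (unknown)

# Toy mask: raw label 3 on the left, 7 on the right, 5 is not a listed class.
mask = torch.tensor([[3, 3, 7, 7],
                     [3, 3, 5, 7]])
print(remap_all_classes(mask, subclasses=[3, 7]))  # values [[1, 1, 2, 2], [1, 1, 0, 2]]
```

This only changes the mask. It does not address whether the model's class conditioning can handle more than one class index per sample.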

joihn commented 2 months ago

Hmm, maybe it's to minimize the influence of class imbalance within a mask?
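One way to see what class imbalance within a mask means is to measure how many pixels each class occupies. If one class covers most of a multi-class mask, a per-pixel loss averaged over that mask is dominated by it. Training on one class at a time would sidestep this, which is the hypothesis raised above and is not confirmed anywhere in the thread. The helper `class_pixel_fractions` below is hypothetical and purely illustrative.

```python
import torch

def class_pixel_fractions(mask: torch.Tensor) -> dict:
    """Fraction of pixels occupied by each class index in a mask (hypothetical helper)."""
    labels, counts = torch.unique(mask, return_counts=True)  # labels come back sorted
    total = mask.numel()
    return {int(label): count.item() / total for label, count in zip(labels, counts)}

# Toy 4x4 mask: class 1 covers 12 pixels, class 2 covers 3, unknown (0) covers 1.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [2, 2, 2, 0]])
print(class_pixel_fractions(mask))  # {0: 0.0625, 1: 0.75, 2: 0.1875}
```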