lhoyer / DAFormer

[CVPR22] Official Implementation of DAFormer: Improving Network Architectures and Training Strategies for Domain-Adaptive Semantic Segmentation
Other
466 stars, 92 forks

The get_class_masks function #19

Closed: JunXieFront closed this issue 2 years ago

JunXieFront commented 2 years ago
```python
import numpy as np
import torch


def get_class_masks(labels):
    class_masks = []
    for label in labels:
        # Takes the classes of the whole batch (`labels`), not of this sample (`label`)
        classes = torch.unique(labels)
        nclasses = classes.shape[0]
        class_choice = np.random.choice(
            nclasses, int((nclasses + nclasses % 2) / 2), replace=False)
        classes = classes[torch.Tensor(class_choice).long()]
        class_masks.append(generate_class_mask(label, classes).unsqueeze(0))
    return class_masks
```
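The snippet calls `generate_class_mask`, which is not shown in the issue. Below is a hypothetical NumPy re-implementation, assuming the helper returns a binary mask that is 1 wherever the label belongs to any of the selected classes. The `(H, W)` input and output shapes are assumptions for illustration; the repository's torch version may use extra leading dimensions.

```python
import numpy as np


def generate_class_mask(label, classes):
    """Hypothetical stand-in for the helper used in get_class_masks.

    Assumes `label` is an (H, W) integer array. Returns an (H, W) int mask:
    1 where the pixel's class is in `classes`, 0 elsewhere.
    """
    # Broadcast (1, H, W) against (K, 1, 1), then mark pixels matching any selected class.
    hits = label[None, :, :] == np.asarray(classes)[:, None, None]
    return hits.any(axis=0).astype(np.int64)


label = np.array([[0, 0, 1],
                  [2, 2, 1],
                  [3, 3, 3]])
mask = generate_class_mask(label, [1, 3])
print(mask)
```

If none of the selected classes occur in `label`, the mask is all zeros, so nothing from that sample is mixed.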

Should `classes = torch.unique(labels)` be `classes = torch.unique(label)`? As written, the classes are collected from the whole batch rather than from the current sample.
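For comparison, here is a minimal NumPy sketch of the per-sample variant the question proposes (the `torch.unique(label)` behaviour). The function name `get_class_masks_per_sample`, the `(B, H, W)` label layout, and the use of `np.isin` in place of `generate_class_mask` are assumptions for illustration, not the repository's code.

```python
import numpy as np


def get_class_masks_per_sample(labels, rng=None):
    """Hypothetical per-sample ClassMix class selection.

    Assumes `labels` is an integer array of shape (B, H, W). Returns a list of
    (H, W) masks, one per sample, marking pixels of the selected classes.
    """
    rng = np.random.default_rng() if rng is None else rng
    class_masks = []
    for label in labels:
        classes = np.unique(label)  # classes present in THIS sample only
        nclasses = classes.shape[0]
        # Same rounding as the original code: ceil(nclasses / 2) classes are kept.
        nchosen = (nclasses + nclasses % 2) // 2
        chosen = classes[rng.choice(nclasses, nchosen, replace=False)]
        class_masks.append(np.isin(label, chosen).astype(np.int64))
    return class_masks


labels = np.array([[[0, 0], [1, 1]],   # sample 0: classes {0, 1}
                   [[2, 3], [4, 5]]])  # sample 1: classes {2, 3, 4, 5}
masks = get_class_masks_per_sample(labels, np.random.default_rng(0))
```

Whatever the seed, each mask here covers exactly half of its sample's classes (one of two for sample 0, two of four for sample 1), and therefore two pixels each.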

lhoyer commented 2 years ago

Yes, you are right: our implementation differs slightly from the original DACS version. Instead of selecting half of the classes present in a single sample for ClassMix, we select half of all classes present in the batch. In practice, this means that possibly fewer than half of the classes in a sample are selected for ClassMix. Even though this difference was unintended, it still works well, so we will keep it to ensure that the released code stays consistent with the results reported in our paper.
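To make the effect of batch-level selection concrete, the sketch below enumerates every equally likely draw of `ceil(n/2)` classes from a hypothetical batch and reports how much of one sample's class set ends up selected. The batch contents are invented for illustration; the only assumption about the real code is that `np.random.choice(..., replace=False)` picks each subset of a given size with equal probability.

```python
from fractions import Fraction
from itertools import combinations


def selected_share_distribution(batch_classes, sample_classes):
    """Exact distribution of the share of `sample_classes` that gets selected
    when ceil(n/2) classes are drawn uniformly, without replacement, from all
    classes in the batch."""
    n = len(batch_classes)
    k = (n + n % 2) // 2  # same rounding as get_class_masks
    draws = list(combinations(batch_classes, k))  # every subset equally likely
    dist = {}
    for draw in draws:
        share = Fraction(len(set(draw) & set(sample_classes)), len(sample_classes))
        dist[share] = dist.get(share, 0) + Fraction(1, len(draws))
    return dict(sorted(dist.items()))


# Hypothetical batch: sample A contains classes {0, 1}, sample B contains {2, 3}.
dist = selected_share_distribution([0, 1, 2, 3], [0, 1])
for share, prob in dist.items():
    print(f"share of sample A's classes selected = {share}: probability {prob}")
```

In this toy batch, per-sample selection would always pick exactly one of sample A's two classes. With batch-level selection, one draw in six picks none of them (an empty mask for that sample), one in six picks both, and the rest pick one, so the expected share is still one half.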