CuriousAI / mean-teacher

A state-of-the-art semi-supervised method for image recognition
https://arxiv.org/abs/1703.01780

How to train the model with unlabeled data? #59

Open broken-dream opened 2 years ago

broken-dream commented 2 years ago

I want to transfer the MT framework to an NLP task, but I don't understand how to train it with unlabeled data. I've got the idea of the paper, but I'm confused about the implementation.

    # The model may return a single logit tensor or a (class, consistency) pair.
    if isinstance(model_out, Variable):
        assert args.logit_distance_cost < 0  # single head: no residual-logit penalty
        logit1 = model_out
        ema_logit = ema_model_out
    else:
        assert len(model_out) == 2
        assert len(ema_model_out) == 2
        logit1, logit2 = model_out
        ema_logit, _ = ema_model_out

    # Detach the teacher's logits so no gradient flows into the EMA model.
    ema_logit = Variable(ema_logit.detach().data, requires_grad=False)

    if args.logit_distance_cost >= 0:
        # Two heads: one for classification, one for consistency,
        # kept close to each other by a residual penalty.
        class_logit, cons_logit = logit1, logit2
        res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
        meters.update('res_loss', res_loss.data[0])
    else:
        class_logit, cons_logit = logit1, logit1
        res_loss = 0

    # Supervised classification loss on the student's logits.
    class_loss = class_criterion(class_logit, target_var) / minibatch_size
    meters.update('class_loss', class_loss.data[0])

    ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
    meters.update('ema_class_loss', ema_class_loss.data[0])

    if args.consistency:
        # Consistency loss between student and teacher predictions;
        # this term requires no labels.
        consistency_weight = get_current_consistency_weight(epoch)
        meters.update('cons_weight', consistency_weight)
        consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
        meters.update('cons_loss', consistency_loss.data[0])
    else:
        consistency_loss = 0
        meters.update('cons_loss', 0)

I notice that the TwoStreamBatchSampler divides the dataset into a labeled part and an unlabeled part, but the code above seems to handle labeled and unlabeled data uniformly. I think only the labeled part of model_out should be used to calculate the class_loss. Did I get it wrong?
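
For reference, my understanding is that the unlabeled samples enter training only through the consistency term, which compares the student's and teacher's predictions and needs no labels at all. A minimal sketch of a softmax-MSE consistency criterion (the repo keeps its real implementations in losses.py; the code below is just my illustration):

    import torch.nn.functional as F

    def softmax_mse_loss(input_logits, target_logits):
        # MSE between the softmax outputs of student and teacher.
        # No labels are involved, so this term is computed for labeled
        # and unlabeled samples alike.
        assert input_logits.size() == target_logits.size()
        num_classes = input_logits.size(1)
        input_softmax = F.softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)
        # Sum over the batch; the training loop divides by minibatch_size.
        return F.mse_loss(input_softmax, target_softmax, reduction='sum') / num_classes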

jetoHui520 commented 2 years ago

I think you may want to take a look at data.py and datasets.py, which are the modules that load your dataset into the training pipeline.
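
Roughly, the TwoStreamBatchSampler in data.py builds each batch from two index streams, so every batch contains a fixed quota of labeled samples alongside the unlabeled ones. A simplified sketch of the idea (argument names and the cycling detail are illustrative, not the exact repo code):

    import itertools
    import numpy as np

    def grouper(iterable, n):
        # Collect items into fixed-length chunks, dropping any remainder.
        args = [iter(iterable)] * n
        return zip(*args)

    class TwoStreamBatchSampler:
        # Yields batches mixing two index streams: mostly unlabeled indices
        # plus a fixed number of labeled indices per batch.
        def __init__(self, primary_indices, secondary_indices,
                     batch_size, secondary_batch_size):
            self.primary_indices = primary_indices        # e.g. unlabeled
            self.secondary_indices = secondary_indices    # e.g. labeled
            self.primary_batch_size = batch_size - secondary_batch_size
            self.secondary_batch_size = secondary_batch_size

        def __iter__(self):
            primary = np.random.permutation(self.primary_indices)
            # Cycle the smaller labeled stream so it never runs out.
            secondary = itertools.cycle(np.random.permutation(self.secondary_indices))
            for p, s in zip(grouper(primary, self.primary_batch_size),
                            grouper(secondary, self.secondary_batch_size)):
                yield list(p) + list(s)

    # Example: batches of 8 with 3 labeled samples each.
    sampler = TwoStreamBatchSampler(range(100, 200), range(0, 10),
                                    batch_size=8, secondary_batch_size=3)
    print(next(iter(sampler)))  # 5 unlabeled indices followed by 3 labeled ones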

amstrudy commented 1 month ago

@tarvaina So sorry to bother you several years on, but I have this same question. During training, when some of the samples in each batch are unlabeled (-1), how is it that the class loss is calculated on both labeled and unlabeled samples?

In this line: class_loss = class_criterion(class_logit, target_var) / minibatch_size we have class_logit, which for a batch size of 8 should have dimension 8 x 1000. Then we have target_var, a tensor of length 8 holding the class labels (e.g. [-1, -1, -1, -1, 2, 4, 5, 5]).

Can you clarify how this is working? Why are the unlabeled samples being taken into account? Thank you so much!

amstrudy commented 1 month ago

Replying in case this is helpful for others: when the class_criterion is created, NO_LABEL targets are set to be ignored, which I believe handles this correctly: class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL).cuda()
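
A quick self-contained check of that behavior (not from the repo, just a demonstration that ignore_index drops the unlabeled targets from the sum):

    import torch
    import torch.nn as nn

    NO_LABEL = -1
    criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

    logits = torch.randn(8, 1000)                        # batch of 8, 1000 classes
    target = torch.tensor([-1, -1, -1, -1, 2, 4, 5, 5])  # 4 unlabeled, 4 labeled

    full = criterion(logits, target)
    labeled_only = criterion(logits[4:], target[4:])
    print(torch.allclose(full, labeled_only))  # True: the -1 entries contribute nothing

Note that the training code above still divides this sum by the full minibatch_size, so the labeled examples' gradients are averaged over the whole batch rather than over the labeled subset only.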