arxyzan / data2vec-pytorch

PyTorch implementation of "data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language" from Meta AI
MIT License
172 stars, 26 forks

EMA model forward #11

Closed (anhvth closed 2 years ago)

anhvth commented 2 years ago
        # model forward in online mode (student)
        x = self.encoder(src, mask, **kwargs)['encoder_out']  # fetch the last layer outputs
        if trg is None:
            return x

        # model forward in offline mode (teacher)
        with torch.no_grad():
            self.ema.model.eval()
            y = self.ema.model(trg, ~mask, **kwargs)['encoder_states']  # fetch the last transformer layers outputs

In the teacher forward pass, mask_time_indices is the inverse of the one used in the student pass. Is this correct? I think the mask in the teacher forward pass should be None, since the teacher expects the full, unmasked input.

arxyzan commented 2 years ago

Hi @anhvth, thanks for your feedback. The teacher predicts representations at the masked indices of the input: the indices that are masked for src are not masked for trg, and vice versa. So the teacher's mask must be the inverse of the student's.
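
For readers comparing the two options discussed in this thread, here is a minimal toy sketch of what each mask choice exposes to the teacher. It is not the repository's code: plain Python lists stand in for tensors so it runs without PyTorch, `True` is assumed to mean "masked", and `MASK_TOKEN`, `apply_mask`, and `invert` are hypothetical names introduced only for illustration. On a boolean PyTorch tensor, the inversion would be written `~mask`, as in the quoted snippet.

```python
# Toy illustration only: lists stand in for tensors, True means "masked".
MASK_TOKEN = "_"  # hypothetical placeholder for a masked position


def apply_mask(tokens, mask=None):
    """Replace masked positions with MASK_TOKEN; mask=None leaves the input intact."""
    if mask is None:
        return list(tokens)
    return [MASK_TOKEN if m else t for t, m in zip(tokens, mask)]


def invert(mask):
    """List equivalent of `~mask` on a boolean tensor."""
    return [not m for m in mask]


tokens = ["a", "b", "c", "d", "e"]
mask = [False, True, True, False, False]

# Student input: positions 1 and 2 are hidden.
student_view = apply_mask(tokens, mask)

# Option from the reply: the teacher gets the inverted mask,
# so it sees exactly the positions hidden from the student.
teacher_view_inverse = apply_mask(tokens, invert(mask))

# Option from the question: the teacher gets no mask and sees everything.
teacher_view_full = apply_mask(tokens, None)

print(student_view)
print(teacher_view_inverse)
print(teacher_view_full)
```

With the inverted mask, the student's and the teacher's visible positions are disjoint and together cover the whole input; with `None`, the teacher's view is a superset of the student's.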