facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

[data2vec] Trying to implement a standalone pytorch version #4200

Closed. arxyzan closed this issue 2 years ago.

arxyzan commented 2 years ago

First off, thank you @alexeib for this great work on data2vec. I'm currently trying to implement the data2vec model in pure PyTorch. I've read the paper and the code, and I have a couple of questions:

  1. What exactly is the `target_tokens` argument in the forward method of `Data2VecTextModel` and `Data2VecTextEncoder`?
  2. Why are the encoder outputs (in other words, the representations themselves) masked before applying the regression?

     ```python
     # Pseudo-code of the Data2VecTextEncoder forward method
     x = self.extract_features(src_tokens)
     with torch.no_grad():
         y = self.ema.model(target_tokens)
         y = self.norm(y)

     # Why mask the representations?
     x = x[masked_indices]
     y = y[masked_indices]
     x = self.regression_head(x)
     loss = self.criterion(x, y)
     ...
     ```

  3. Why are the outputs used in teacher mode taken from `fc_results`, while in student mode they are taken from `encoder_out`?
  4. Why are the losses calculated in the forward method of the data2vec encoders?
  5. For inference, which model should actually be used: `Data2VecTextModel` or the inner encoder model, e.g., RoBERTa? My assumption is that the forward method must be called with `features_only` set to `True`, which yields the encoder outputs, and that a classification layer is then applied on top of them in the `Data2VecModel` forward method.
  6. To implement a general framework that can easily be applied to different encoders and modalities, I have the following schema in mind. Is this the best practice, or is it even possible? I'm unsure because the paper mentions that masking and feature extraction require some modality-specific procedures. My way of tackling this is to implement these strategies as methods in the encoder itself.
Here's the prototype code:

```python
# this is a very basic prototype
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ema import EMA  # local helper that keeps an EMA (teacher) copy of the encoder


class Data2Vec(nn.Module):
    def __init__(self, encoder, cfg):
        super().__init__()
        self.encoder = Data2VecEncoder(encoder, cfg)
        self.classification_head = nn.Linear(cfg.in_features, cfg.num_classes)

    def forward(self, src, trg=None, do_classification=False):
        # Masking is only applied for pretraining; fine-tuning sees the full input
        if not do_classification:
            src = self.encoder.apply_mask(src)
        # Pretraining needs the teacher targets, so features_only is True only for classification
        encoder_output = self.encoder(src, trg, features_only=do_classification)
        if do_classification:
            return self.classification_head(encoder_output)
        return encoder_output  # (student predictions, teacher targets)


class Data2VecEncoder(nn.Module):
    MODALITIES = ['vision', 'text', 'audio']

    def __init__(self, encoder: nn.Module, cfg):
        super().__init__()
        assert cfg.modality in self.MODALITIES
        self.modality = cfg.modality
        self.encoder = encoder
        self.mask_idx = cfg.mask_idx  # id of the mask token, used to find masked positions

        self.cfg = cfg
        self.teacher = EMA(self.encoder, cfg)
        # projection from student features into the teacher target space
        self.regression_head = nn.Linear(cfg.in_features, cfg.in_features)

    def apply_mask(self, src):
        # modality-specific masking is implemented by the wrapped encoder
        return self.encoder.apply_mask(src)

    def forward(self, src, trg, features_only=False):
        x = self.encoder.extract_features(src)  # student sees the masked input
        if features_only:
            return x

        with torch.no_grad():
            self.teacher.model.eval()

            # teacher sees the unmasked input; keep the top-k layer outputs (each T x B x C)
            y = self.teacher.model(trg)
            y = y[self.cfg.teacher_features]
            y = y[-self.cfg.top_k_layers:]

            if self.modality in ['vision', 'text']:
                # layer-norm each layer, average them, then T x B x C -> B x T x C
                y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
                y = sum(y) / len(y)
                y = y.transpose(0, 1)

            elif self.modality == 'audio':
                # instance-norm each layer over time (B x C x T), back to B x T x C, then average
                y = [tl.permute(1, 2, 0) for tl in y]
                y = [F.instance_norm(tl.float()) for tl in y]
                y = [tl.transpose(1, 2) for tl in y]
                y = sum(y) / len(y)

            # optional normalization of the averaged targets
            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])
            elif self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

        # regress only at the masked positions (B x T boolean mask; token inputs only,
        # other modalities would need the mask produced by apply_mask)
        masked_indices = src.eq(self.mask_idx)
        x = x[masked_indices]
        y = y[masked_indices]

        x = self.regression_head(x)

        return x, y
```

Thanks in advance for any help or recommendations.

Best,
Aryan
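Questions 2 and 4 both concern the regression step. Below is a minimal sketch of it, assuming student features and teacher targets already shaped (batch, time, dim) and a boolean (batch, time) mask marking the masked positions. The function name `masked_regression_loss` and its `beta` switch (0 for MSE, otherwise PyTorch's smooth L1 as an L1-style loss) are hypothetical, for illustration only, and not fairseq's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_regression_loss(x, y, mask, head, beta=0.0):
    """Regress teacher targets y from student features x at masked positions only.

    x, y: (batch, time, dim) student features and teacher targets
    mask: (batch, time) boolean, True where the input was masked
    head: module projecting student features into the target space
    beta: 0 -> mean squared error, > 0 -> smooth L1 with this beta (hypothetical switch)
    """
    x = head(x[mask])  # (num_masked, dim): unmasked positions are dropped
    y = y[mask]        # the same positions from the teacher
    if beta == 0:
        return F.mse_loss(x.float(), y.float())
    return F.smooth_l1_loss(x.float(), y.float(), beta=beta)


# Usage with random tensors standing in for real encoder outputs
torch.manual_seed(0)
batch, time, dim = 2, 5, 8
x = torch.randn(batch, time, dim)
y = torch.randn(batch, time, dim)
mask = torch.zeros(batch, time, dtype=torch.bool)
mask[:, 1:3] = True  # pretend positions 1 and 2 were masked in every sequence
head = nn.Linear(dim, dim)

loss = masked_regression_loss(x, y, mask, head)
print(loss.item())
```

Because the unmasked positions are indexed away before the loss, the result depends only on the masked positions: changing the teacher targets anywhere else leaves the loss unchanged.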

arxyzan commented 2 years ago

Alright, after playing around with the code and reading the paper more carefully, I have figured out some of the answers, which I've put below for anyone wondering:

  1. Pretraining data2vec for text is set up as a masked language modeling task, configured in the config file (`task: masked_lm`), in which `src_tokens` is the masked version of the input and `target_tokens` is the original, unmasked input. For audio, by contrast, the original input is fed to the model and masking is done within the forward method.
  2. The loss only concerns the masked positions, so the representations at all other positions are dropped from the tensor before regression (for both the student outputs and the teacher targets).
  3. This is detailed in Table 4 of the paper: using the FFN output as the teacher target gives the lowest error rate in practice.
  4. Data2VecText is trained on an MLM task, but it actually uses `fairseq.criterions.model_criterion.ModelCriterion`, which relies on the model to provide the losses. The reason for not using `fairseq.criterions.masked_lm.MaskedLmLoss`, as RoBERTa does, is that data2vec uses either an MSE or an L1 regression loss instead of the cross-entropy loss provided in `fairseq.modules.cross_entropy`, so the losses are more conveniently calculated inside the forward method of Data2Vec.
  5. The assumption is correct. Fine-tuning is done using the encoder itself plus projection layers for classification, etc.
  6. It seems possible. You can follow my implementation in my repo.
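Answer 1 can be illustrated with a minimal sketch of how such a text input pair might be built. It assumes a simple per-token Bernoulli mask in which every selected token is replaced by the mask id (BERT-style random or unchanged replacements are omitted) and padding is never selected. The function `make_mlm_pair`, the 15% default rate and the toy token ids are hypothetical; this is not fairseq's implementation. The masked positions are then recovered with `src_tokens.eq(mask_idx)`, as in the prototype above.

```python
import torch


def make_mlm_pair(tokens, mask_idx, pad_idx, mask_prob=0.15, generator=None):
    """Return (src_tokens, target_tokens) for a masked-LM style batch (hypothetical helper).

    tokens: (batch, time) LongTensor of original token ids
    Every selected position is replaced by mask_idx; padding is never selected.
    """
    target_tokens = tokens.clone()  # teacher input: the original tokens
    probs = torch.rand(tokens.shape, generator=generator)
    selected = (probs < mask_prob) & tokens.ne(pad_idx)
    src_tokens = tokens.masked_fill(selected, mask_idx)  # student input: masked tokens
    return src_tokens, target_tokens


# Usage with a toy vocabulary: 1 = padding, 3 = mask id (hypothetical values)
PAD, MASK = 1, 3
tokens = torch.tensor([[5, 6, 7, 8, 9, 1],
                       [10, 11, 12, 13, 1, 1]])
g = torch.Generator().manual_seed(0)
src_tokens, target_tokens = make_mlm_pair(tokens, MASK, PAD, mask_prob=0.5, generator=g)

masked_positions = src_tokens.eq(MASK)  # positions where the student must regress the teacher
print(src_tokens)
print(masked_positions)
```

Only `src_tokens` is fed to the student, while `target_tokens` goes to the EMA teacher, matching the split described in answer 1.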