lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

input_mask behavior #27

Open AliOskooeiTR opened 3 years ago

AliOskooeiTR commented 3 years ago

I have a question about how input_mask works in RoutingTransformerLM. I have been using a random mask (with causal=False), as in MLM, and varying the masking ratio, but the ratio does not seem to affect how the model learns. I even went to the extreme of masking 90% of the inputs, and the model still kept learning rapidly. I am training the LM with the Hugging Face Trainer; my compute_loss method is below for reference. I have checked both the mask and the input data, and they look fine.
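For reference, here is a minimal sketch of a random MLM-style position mask at a given ratio. It is a sketch under the same convention the code below uses (True = position kept, False = position masked). The helper name random_position_mask is hypothetical and not part of routing-transformer. Unlike sampling one set of positions for the whole batch, it picks the masked positions independently for each sequence.

```python
from typing import Optional

import torch


def random_position_mask(batch_size: int, seq_len: int, mask_ratio: float,
                         generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical helper: bool mask of shape (batch_size, seq_len).

    True keeps a position, False masks it. Exactly int(mask_ratio * seq_len)
    positions are masked in every row, chosen independently per sequence.
    """
    num_masked = int(mask_ratio * seq_len)
    # One random score per position; the num_masked lowest scores in each row are masked
    scores = torch.rand(batch_size, seq_len, generator=generator)
    masked_idx = scores.argsort(dim=1)[:, :num_masked]
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    # Row indices broadcast against masked_idx so each row gets its own positions
    rows = torch.arange(batch_size).unsqueeze(1)
    mask[rows, masked_idx] = False
    return mask
```

Counting the False entries per row, for example (~mask).sum(dim=1), is a quick way to confirm that the effective ratio matches mask_ratio.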

```python
import random

import torch
import torch.nn.functional as F


def compute_loss(self, model, inputs):
    # Token ids for the batch, shape (batch, seq_len)
    source = inputs["input_ids"].to(self.args.device)

    # Boolean mask on the same device as source: True = keep, False = masked
    input_mask = torch.ones_like(source, dtype=torch.bool)
    masked_tokens = random.sample(
        range(source.shape[1]),
        int(self.args.mask_ratio * source.shape[1])
    )
    # Mask the same positions in every sequence of the batch
    input_mask[:, masked_tokens] = False

    output, aux_loss = model(
        source,
        input_mask=input_mask,
        return_loss=True
    )
    # output is (batch, seq_len, vocab); cross_entropy expects (batch, vocab, seq_len)
    loss = F.cross_entropy(
        output.transpose(1, 2),
        source
    ) + aux_loss

    return loss.squeeze()
```
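One plausible explanation, offered as an assumption rather than a confirmed reading of the library: if input_mask only controls which positions attention may look at (a padding-style mask), the masked tokens' ids are still embedded and passed through the network. The loss targets are also the unmodified source at every position. A non-causal model can then largely copy its own input, whatever the ratio. A common BERT-style alternative, sketched below, replaces the selected tokens in the input and scores only those positions. The names mlm_corrupt, mlm_loss and mask_token_id are hypothetical, and the sketch assumes the vocabulary reserves an id for a [MASK] token.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of F.cross_entropy


def mlm_corrupt(source: torch.Tensor, keep_mask: torch.Tensor, mask_token_id: int):
    """Hide selected tokens from the input and keep targets only for them.

    source:        (batch, seq_len) token ids
    keep_mask:     (batch, seq_len) bool, False = position selected for MLM
    mask_token_id: id of a reserved [MASK] token (assumed to exist in the vocab)
    """
    selected = ~keep_mask
    # The model can no longer read the original id at selected positions
    corrupted = source.masked_fill(selected, mask_token_id)
    # Targets: original ids at selected positions, ignored everywhere else
    labels = source.masked_fill(~selected, IGNORE_INDEX)
    return corrupted, labels


def mlm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab) -> cross_entropy wants (batch, vocab, seq_len)
    return F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=IGNORE_INDEX)
```

In compute_loss, the model would then receive corrupted instead of source, and the loss would be mlm_loss(output, labels) + aux_loss. With this setup, a higher mask ratio should make the task noticeably harder. BERT additionally leaves some selected tokens unchanged or swaps them for random tokens (the 80/10/10 scheme); that refinement is omitted here.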