namisan / mt-dnn

Multi-Task Deep Neural Networks for Natural Language Understanding
MIT License

Problem in SMART embedding #230

Open kongds opened 2 years ago

kongds commented 2 years ago

Thank you for providing the code of SMART. SMART uses the following line to get the embeddings, which are then perturbed with noise and fed to BERT as inputs_embeds: https://github.com/namisan/mt-dnn/blob/ca896ef1f9de561f1741221d2c98b4d989e3ed19/mt_dnn/matcher.py#L124

But in transformers, inputs_embeds is expected to be the output of bert.embeddings.word_embeddings, not the output of bert.embeddings. Please refer to the following code for BertEmbeddings.forward:

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
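Because inputs_embeds replaces only the word-embedding lookup in the forward above (position and token-type embeddings, LayerNorm, and dropout are still applied afterwards), passing the output of bert.embeddings back in applies those steps twice. A minimal sketch of the difference, using a toy module that mirrors the quoted forward (all sizes are illustrative and dropout is omitted for determinism; this is not SMART's actual code):

```python
import torch
import torch.nn as nn

# Toy module mirroring the quoted BertEmbeddings.forward (hypothetical sizes;
# dropout omitted so results are deterministic).
class ToyEmbeddings(nn.Module):
    def __init__(self, vocab=100, hidden=16, max_len=32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab, hidden)
        self.position_embeddings = nn.Embedding(max_len, hidden)
        self.token_type_embeddings = nn.Embedding(2, hidden)
        self.LayerNorm = nn.LayerNorm(hidden)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        positions = torch.arange(inputs_embeds.size(1))
        token_types = torch.zeros(inputs_embeds.shape[:-1], dtype=torch.long)
        embeddings = (inputs_embeds
                      + self.token_type_embeddings(token_types)
                      + self.position_embeddings(positions))
        return self.LayerNorm(embeddings)

torch.manual_seed(0)
emb = ToyEmbeddings()
input_ids = torch.randint(0, 100, (1, 8))
ref = emb(input_ids=input_ids)

# Correct for SMART: perturb the word-embedding lookup only; the module still
# adds position/token-type embeddings and LayerNorm exactly once.
ok = emb(inputs_embeds=emb.word_embeddings(input_ids))

# What the linked matcher.py line does: feed the full embedding output back in,
# so position/token-type embeddings and LayerNorm get applied a second time.
bad = emb(inputs_embeds=emb(input_ids=input_ids))

print(torch.allclose(ok, ref), torch.allclose(bad, ref))  # True False
```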
kongds commented 2 years ago

Another problem is that the magnitude of delta_grad * self.step_size is too small to influence the noise during the noise update. For example, delta_grad * self.step_size is around 1e-13, while the magnitude of the noise is around 1e-5.

As a result, the code below effectively updates the noise only through its norm projection, with delta_grad contributing nothing: https://github.com/namisan/mt-dnn/blob/ca896ef1f9de561f1741221d2c98b4d989e3ed19/mt_dnn/perturbation.py#L131-L134
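To see why the delta_grad term vanishes: in float32, the spacing between representable numbers near 1e-5 is about 1e-12, so adding a ~1e-13 increment is rounded away entirely and the noise tensor is left bit-identical. A minimal numeric sketch (magnitudes taken from the report above; the values themselves are illustrative):

```python
import torch

# Illustrative magnitudes: accumulated noise ~1e-5, gradient step ~1e-13.
noise = torch.full((4,), 1e-5)
delta_grad = torch.full((4,), 1e-10)
step_size = 1e-3  # delta_grad * step_size ~ 1e-13

updated = noise + delta_grad * step_size

# The ULP of float32 near 1e-5 is ~1.2e-12, so a 1e-13 increment is below
# half an ULP and rounds to the original value: the gradient step is a no-op,
# leaving the subsequent norm projection as the only effective operation.
print(torch.equal(updated, noise))  # True
```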