ant-research / StructuredLM_RTDT

A library for building hierarchical text representation and corresponding downstream applications.
Apache License 2.0

A potential issue found in the nn.MultiheadAttention module setup #2

Status: Closed (frankaging closed this issue 2 years ago)

frankaging commented 2 years ago

Hi Authors!

Thanks for sharing this repo. I enjoyed reading your paper, and I am working on a related project. While going through the code, I found a potential issue with the current setup. Below I (1) explain the issue and (2) provide a simple test case that I ran on my end. Could you help verify it?

Issue:

Code Analysis: In r2d2.py, the encoder is called as follows:

        tasks_embedding = self.embedding(task_ids)  # (?, 2, dim)
        input_embedding = torch.cat([tasks_embedding, tensor_batch], dim=1)  # (?, 4, dim)
        outputs = self.tree_encoder(input_embedding.transpose(0, 1)).transpose(0, 1)  # (? * batch_size, 4, dim)

We can see that input_embedding has batch_size as its first dimension, since it is built by concatenating (along dim=1) the output of the nn.Embedding module with tensor_batch. The .transpose(0, 1) applied before self.tree_encoder moves batch_size to the second dimension instead, so in this case the first dimension is always 4.
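A minimal standalone sketch of that shape flow follows. It assumes tensor_batch holds the left and right child embeddings with shape (batch_size, 2, dim), uses a hypothetical two-entry nn.Embedding for the task ids, and picks sizes (8 candidates, 768 dimensions) only to match the logs in this issue.

```python
import torch
import torch.nn as nn

# Hypothetical sizes chosen to match the logs in this issue: 8 candidates, 768-dim embeddings.
batch_size, dim = 8, 768

# Assumption: two task ids per candidate, looked up in a small nn.Embedding table.
embedding = nn.Embedding(num_embeddings=2, embedding_dim=dim)
task_ids = torch.tensor([[0, 1]]).repeat(batch_size, 1)  # (batch_size, 2)
tensor_batch = torch.randn(batch_size, 2, dim)  # assumed left/right children: (batch_size, 2, dim)

tasks_embedding = embedding(task_ids)  # (batch_size, 2, dim)
input_embedding = torch.cat([tasks_embedding, tensor_batch], dim=1)  # (batch_size, 4, dim)
transposed = input_embedding.transpose(0, 1)  # (4, batch_size, dim): batch is now dim 1

print(input_embedding.shape)  # torch.Size([8, 4, 768])
print(transposed.shape)       # torch.Size([4, 8, 768])
```

Whether the transposed layout is correct depends entirely on how the attention module inside the encoder interprets its first two dimensions.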

Testing Done: I added some print statements inside TreeEncoderLayer:

    def forward(self, src, src_mask=None, pos_ids=None):
        """
        :param src: concatenation of task embeddings and representation for left and right.
                    src shape: (task_embeddings + left + right, batch_size, dim)
        :param src_mask:
        :param pos_ids:
        :return:
        """
        if len(pos_ids.shape) == 1:
            sz = src.shape[0]  # sz: batch_size
            pos_ids = pos_ids.unsqueeze(0).expand(sz, -1)  # (3, batch_size)
        position_embedding = self.position_embedding(pos_ids)
        print("pre: ", src.shape)
        print("pos_emb: ", position_embedding.shape)
        output = self.self_attn(src + position_embedding, src + position_embedding, src, attn_mask=src_mask)
        src2 = output[0]
        attn_weights = output[1]
        print("attn_w: ", attn_weights.shape)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        print("post: ", src.shape)
        return src

This is the output:

    pre:  torch.Size([4, 8, 768])
    pos_emb:  torch.Size([4, 8, 768])
    attn_w:  torch.Size([4, 8, 8])
    post:  torch.Size([4, 8, 768])

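Summary: It seems that in r2d2.py, self-attention is not computed over the 4 tokens of each candidate (2 special prefix tokens plus the left and right child embeddings) but across the full collection of candidates and their children. The attn_w shape supports this: nn.MultiheadAttention returns batched attention weights as (batch, target_len, source_len), so a shape of (4, 8, 8) means the module treats 4 as the batch size and attends over the 8 candidates.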
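The logged attention-weight shape can be reproduced in isolation. This sketch assumes the layer's self_attn was constructed with batch_first=True (the issue does not show its constructor, but the logged (4, 8, 8) shape is consistent with it); the 12 heads are an arbitrary choice, and batch_first requires PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, dim, num_heads = 8, 4, 768, 12  # num_heads is an arbitrary choice

# Assumption: batch_first=True, inferred from the logged attn_w shape, not from the source.
attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

x = torch.randn(batch_size, seq_len, dim)  # (batch_size, 4, dim), as built in r2d2.py

# r2d2.py-style call: dim 0 (=4) is read as batch and dim 1 (=8) as the sequence.
x_t = x.transpose(0, 1)
_, w_transposed = attn(x_t, x_t, x_t)
print(w_transposed.shape)  # torch.Size([4, 8, 8])

# r2d2_cuda.py-style call: attention runs over the 4 tokens of each candidate.
_, w_direct = attn(x, x, x)
print(w_direct.shape)  # torch.Size([8, 4, 4])
```

The weights are averaged over heads by default, which is why there is no head dimension in either shape.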

This issue does not appear in r2d2_cuda.py, which passes input_embedding without the transpose:

            outputs = self.tree_encoder(
                input_embedding)  # (? * batch_size, 4, dim)
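One way to check which layout keeps candidates independent is to perturb a single candidate and see whether another candidate's output moves. The encode helper and the small sizes here are hypothetical and illustrative, not taken from either file; the sketch again assumes batch_first=True and no dropout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, dim = 8, 4, 16  # small illustrative sizes

# Assumption: batch_first=True, as inferred from the logged attention-weight shape.
attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)
attn.eval()  # eval mode; dropout already defaults to 0.0

def encode(x, transpose):
    """Hypothetical helper: self-attention over x of shape (batch_size, 4, dim)."""
    if transpose:  # r2d2.py-style call: swap dims 0 and 1 before and after attention
        x_t = x.transpose(0, 1)
        out, _ = attn(x_t, x_t, x_t)
        return out.transpose(0, 1)
    out, _ = attn(x, x, x)  # r2d2_cuda.py-style call: pass (batch_size, 4, dim) directly
    return out

x = torch.randn(batch_size, seq_len, dim)
x_perturbed = x.clone()
x_perturbed[0] += 1.0  # change only candidate 0

diffs = {}
with torch.no_grad():
    for transpose in (False, True):
        change = encode(x, transpose)[1] - encode(x_perturbed, transpose)[1]
        diffs[transpose] = change.abs().max().item()
        print(f"transpose={transpose}: max change in candidate 1 = {diffs[transpose]:.2e}")
```

Under these assumptions, the untransposed call should leave candidate 1 unchanged, while the transposed call should not, because candidate 0 then acts as a key and value for every other candidate.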

This looks correct, though I have not checked the rest of r2d2_cuda.py. Given this, I am wondering whether the numbers in either of your papers need to be updated. Either way, I am not blocked by this issue, and your codebase has inspired me a lot. Thanks!

imhuim982 commented 2 years ago

Many thanks for your feedback. r2d2.py was used in our first paper, accepted at ACL 2021, and r2d2_cuda.py was used in the paper accepted at EMNLP. Because r2d2.py is not used in our latest work, we did not run regression tests on it, which may have introduced some discrepancies. If you are interested in the first work, I've created a tag with the original code for that paper: https://github.com/alipay/StructuredLM_RTDT/tree/r2d2. The current branch only supports the CUDA version (r2d2_cuda.py). Since r2d2.py is legacy code, we'll consider either fixing the discrepancy or removing the file entirely. The numbers in the paper, however, were obtained with the correct version. If you have a CUDA environment, I suggest using the latest version (Fast-R2D2), which is almost 30 times faster than R2D2 and performs better on downstream tasks.

frankaging commented 2 years ago

Thanks for your quick response! Closing the issue, since the problem does not occur in the r2d2 tag.

imhuim982 commented 2 years ago

I've checked in a bug-fixed version of r2d2 on the latest branch. We will soon release a Fast-R2D2 model pretrained on wiki103; hope it will be helpful for your work :)