Louis-udm / VGCN-BERT

Unable to understand how the 16-dimensional GCN embedding is passed into the BERT transformer #23

Open · rangasaishreyas opened this issue 3 years ago

rangasaishreyas commented 3 years ago

I am trying to understand how the GCN output is passed into the BERT model.

The relevant section of code is in `model_vgcn_bert.py`:

```python
words_embeddings = self.word_embeddings(input_ids)

# Gather word embeddings into vocabulary positions for the GCN:
# [batch, seq, hidden] -> [batch, vocab, hidden] -> transpose -> [batch, hidden, vocab]
vocab_input = gcn_swop_eye.matmul(words_embeddings).transpose(1, 2)
if self.gcn_embedding_dim > 0:
    # Vocabulary GCN output: [batch, hidden, gcn_embedding_dim]
    gcn_vocab_out = self.vocab_gcn(vocab_adj_list, vocab_input)

    gcn_words_embeddings = words_embeddings.clone()
    for i in range(self.gcn_embedding_dim):
        # Flat index into the [batch*seq_len, hidden] view: a slot near the end
        # of each row's valid tokens, offset by row * seq_len per batch row.
        tmp_pos = (attention_mask.sum(-1) - 2 - self.gcn_embedding_dim + 1 + i) \
                  + torch.arange(0, input_ids.shape[0]).to(input_ids.device) * input_ids.shape[1]
        # Write gcn_vocab_out[:, :, i] (shape [batch, hidden]) into those slots.
        gcn_words_embeddings.flatten(start_dim=0, end_dim=1)[tmp_pos, :] = gcn_vocab_out[:, :, i]
```

For a sample batch of size [16, 40], `words_embeddings` has shape [16, 40, 768] and `gcn_vocab_out` has shape [16, 768, 16]. But by the end of the for loop, the contents of `gcn_vocab_out` have somehow been copied into a tensor with the shape of `words_embeddings`, which is then passed into the BERT model. Can you explain what this section of code does?
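To make the indexing concrete, here is a small standalone sketch of what the flatten-and-offset trick appears to do (toy shapes and variable names of my own, not the repo's code):

```python
import torch

# Toy illustration: write one [hidden] vector per batch row into a chosen
# sequence slot, via a flattened [batch*seq_len, hidden] view.
batch, seq_len, hidden = 2, 5, 3
gcn_dim = 1                                # stands in for gcn_embedding_dim
emb = torch.zeros(batch, seq_len, hidden)  # stands in for gcn_words_embeddings
vec = torch.ones(batch, hidden)            # stands in for gcn_vocab_out[:, :, i]

lengths = torch.tensor([5, 4])             # attention_mask.sum(-1): valid tokens per row
i = 0
# Per-row slot inside the sequence, as computed in the snippet above.
pos_in_seq = lengths - 2 - gcn_dim + 1 + i             # -> tensor([3, 2])
# Offset each row's slot by row_index * seq_len to index the flattened view.
tmp_pos = pos_in_seq + torch.arange(batch) * seq_len   # -> tensor([3, 7])

# flatten() on a contiguous tensor returns a view, so this writes into emb.
emb.flatten(start_dim=0, end_dim=1)[tmp_pos, :] = vec
print(emb[0, 3], emb[1, 2])  # both slots now hold the injected vectors
```

So each iteration of the loop overwrites one embedding slot per sequence with one column of the GCN output, which explains how a [16, 768, 16] tensor ends up inside a [16, 40, 768] tensor.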

Also, the paper says the graph embeddings are added as an extension of the BERT word-embedding sequence, but the code instead overwrites positions in a clone of `words_embeddings` (`gcn_words_embeddings`) and passes that to the model rather than `words_embeddings`. Can you please elaborate on this?

Thanks.

Louis-udm commented 1 year ago

I implemented a new version of VGCN-BERT that is much faster, and this old algorithm is deprecated. The new version is available on the Hugging Face Hub: https://huggingface.co/zhibinlu/vgcn-bert-distilbert-base-uncased
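For reference, a minimal loading sketch for that checkpoint (my assumption is that it ships custom modeling code and therefore needs `trust_remote_code=True`; consult the model card for the actual usage, in particular how the vocabulary graph is built and attached):

```python
from transformers import AutoModel, AutoTokenizer

# Hedged loading sketch, not the official usage.
name = "zhibinlu/vgcn-bert-distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, trust_remote_code=True)

# A forward pass may first require constructing the vocabulary graph;
# see the model card for the exact graph-construction steps.
```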