graykode / nlp-tutorial

Natural Language Processing Tutorial for Deep Learning Researchers
https://www.reddit.com/r/MachineLearning/comments/amfinl/project_nlptutoral_repository_who_is_studying/
MIT License

doubts about the TextCNN Code #33

Closed ZihaoZheng98 closed 4 years ago

ZihaoZheng98 commented 5 years ago

```python
class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype)

    def forward(self, X):
        embedded_chars = self.W[X]  # [batch_size, sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1) : [batch_size, channel(=1), sequence_length, embedding_size]

        pooled_outputs = []
        for filter_size in filter_sizes:
            # conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
            conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
            h = F.relu(conv)
            # mp : ((filter_height, filter_width))
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        h_pool = torch.cat(pooled_outputs, len(filter_sizes))  # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size(=6), output_height * output_width * (output_channel * 3)]

        model = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
        return model
```

I wonder if it's wrong to create the conv layer inside the loop? A new `nn.Conv2d` with freshly initialized weights is instantiated on every forward pass, so those weights are never registered with the module and never trained.

endeavor11 commented 4 years ago

Yes, I think so. The author is probably used to TensorFlow 1.x, where defining a layer inside the graph-building code creates its variables only once.
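
For reference, here is a minimal sketch of one way to fix it, assuming the tutorial's globals (`num_filters`, `filter_sizes`, `sequence_length`, `vocab_size`, `embedding_size`, `num_classes`): create the convolutions once in `__init__` and register them via `nn.ModuleList`. This sketch also swaps the raw uniform parameters for `nn.Embedding` and `nn.Linear`, which is a deliberate departure from the original code, not the author's version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        # nn.Embedding replaces the raw self.W parameter (an assumption, not the tutorial's code)
        self.W = nn.Embedding(vocab_size, embedding_size)
        # One Conv2d per filter size, created once and registered via ModuleList,
        # so the weights appear in model.parameters() and are updated by the optimizer
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)
            for filter_size in filter_sizes
        ])
        # nn.Linear replaces the manual Weight/Bias matmul
        self.fc = nn.Linear(self.num_filters_total, num_classes)

    def forward(self, X):
        embedded_chars = self.W(X).unsqueeze(1)  # [batch_size, 1, sequence_length, embedding_size]
        pooled_outputs = []
        for conv in self.convs:
            h = F.relu(conv(embedded_chars))  # [batch_size, num_filters, seq_len - filter_size + 1, 1]
            pooled = F.max_pool2d(h, (h.size(2), 1))  # max over time: [batch_size, num_filters, 1, 1]
            pooled_outputs.append(pooled.squeeze(3).squeeze(2))  # [batch_size, num_filters]
        h_pool_flat = torch.cat(pooled_outputs, dim=1)  # [batch_size, num_filters_total]
        return self.fc(h_pool_flat)  # [batch_size, num_classes]
```

With the layers built in `__init__`, they are constructed exactly once, their weights persist across forward passes, and `model.parameters()` hands them to the optimizer; in the original, the conv weights were re-randomized on every call.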