jiesutd / LatticeLSTM

Chinese NER using Lattice LSTM. Code for ACL 2018 paper.

Hello author, I have a question here that has been bothering me for a long time. Please take a look, I would greatly appreciate your guidance #121

Closed · ume-technology closed this issue 3 years ago

ume-technology commented 3 years ago

In the `LatticeLSTM` class defined in latticelstm.py, its `forward()` method looks like this:

```python
def forward(self, input, skip_input_list, hidden=None):
    """
    input: variable (batch, seq_len), batch = 1
    skip_input_list: [skip_input, volatile_flag]
    skip_input: three-dimensional list whose length is seq_len. Each element is a list of
        matched word ids and their lengths, e.g. [[], [[25,13],[2,3]]]:
        25/13 are word ids, 2/3 are word lengths. skip_input == gaz_list
    """
    volatile_flag = skip_input_list[1]  # gaz_list
    skip_input = skip_input_list[0]
    if not self.left2right:
        skip_input = convert_forward_gaz_to_backward(skip_input)
    input = input.transpose(1, 0)
    seq_len = input.size(0)   # character-level sequence length
    batch_size = input.size(1)
    assert (batch_size == 1)
    hidden_out = []
    memory_out = []
    if hidden:
        (hx, cx) = hidden
    else:
        hx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
        cx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
        if self.gpu:
            hx = hx.cuda()
            cx = cx.cuda()

    id_list = range(seq_len)  # e.g. (0, 57)
    if not self.left2right:
        id_list = list(reversed(id_list))
    input_c_list = init_list_of_objects(seq_len)  # a list of seq_len empty lists: [[], [], [], ...]
    for t in id_list:
        (hx, cx) = self.rnn(input[t], input_c_list[t], (hx, cx))  # character input for this time step
        hidden_out.append(hx)
        memory_out.append(cx)
        if skip_input[t]:
            matched_num = len(skip_input[t][0])  # word ids in gaz: handle the matched word information
            word_var = autograd.Variable(torch.LongTensor(skip_input[t][0]), volatile=volatile_flag)
            if self.gpu:
                word_var = word_var.cuda()
            word_emb = self.word_emb(word_var)  # e.g. 3 * 50: three words matched at this position
            word_emb = self.word_dropout(word_emb)
            ct = self.word_rnn(word_emb, (hx, cx))  # word-level (lexicon) input
            assert (ct.size(0) == len(skip_input[t][1]))
            for idx in range(matched_num):  # matched_num: number of matched words
                length = skip_input[t][1][idx]  # word length, used to locate where the word ends
                if self.left2right:
                    # if t+length <= seq_len - 1:  # each row is a 100-dim vector; unsqueeze --> 1 * 100
                    input_c_list[t + length - 1].append(ct[idx, :].unsqueeze(0))
                else:
                    # if t-length >= 0:
                    input_c_list[t - length + 1].append(ct[idx, :].unsqueeze(0))
    if not self.left2right:
        hidden_out = list(reversed(hidden_out))
        memory_out = list(reversed(memory_out))
    output_hidden, output_memory = torch.cat(hidden_out, 0), torch.cat(memory_out, 0)
    return output_hidden.unsqueeze(0), output_memory.unsqueeze(0)
```
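For reference, the `skip_input` layout described in the docstring can be read as follows. This is a minimal sketch that reuses only the docstring's own toy values, and the printed positions assume the left-to-right branch:

```python
# Toy walk-through of the skip_input format from the docstring above.
# skip_input[t] is either [] (no lexicon word starts at character t) or
# [[word_id, ...], [word_len, ...]], one id/length pair per matched word.
skip_input = [[], [[25, 13], [2, 3]]]

for t, matched in enumerate(skip_input):
    if not matched:
        continue
    word_ids, word_lens = matched
    for word_id, length in zip(word_ids, word_lens):
        # In the left-to-right branch, the word cell state for this match is
        # appended to input_c_list[t + length - 1], i.e. at the word's last character.
        print(f"t={t}: word id {word_id}, length {length}, cell stored at position {t + length - 1}")
```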

After this part of the model finishes computing: `ct = self.word_rnn(word_emb, (hx, cx))`, where is this `ct` actually used? I really cannot figure it out. I have read many articles, and they all say it is combined with an LSTM output for further computation. Where is that LSTM introduced, and where in the code is the use of `ct` reflected? Please help me out. Thank you!

jiesutd commented 3 years ago

Just search for `ct` directly in this piece of code and you will see where it is used. `ct` is stored in `input_c_list`, and `input_c_list` is then consumed by `(hx, cx) = self.rnn(input[t], input_c_list[t], (hx, cx))`.
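A minimal, self-contained sketch of that data flow, using hypothetical toy stand-ins for the repo's `MultiInputLSTMCell` (`self.rnn`) and `WordLSTMCell` (`self.word_rnn`) rather than the real cells: the word cell state `ct` produced at position `t` is stashed in `input_c_list[t + length - 1]`, and the character cell at that later position receives it as an extra input on a subsequent iteration of the same for loop.

```python
import torch

seq_len, hidden_dim = 5, 4
input_c_list = [[] for _ in range(seq_len)]        # plays the role of init_list_of_objects(seq_len)

def toy_word_rnn(word_emb, hc):                    # stand-in for self.word_rnn (WordLSTMCell)
    # one cell state per matched word; the real cell computes these from word_emb and (hx, cx)
    return torch.zeros(word_emb.size(0), hidden_dim)

def toy_char_rnn(char_input, word_cells, hc):      # stand-in for self.rnn (MultiInputLSTMCell)
    hx, cx = hc
    if word_cells:                                 # word cell states stored earlier are merged here
        cx = cx + torch.cat(word_cells, 0).mean(0, keepdim=True)
    return hx, cx

hx = cx = torch.zeros(1, hidden_dim)
for t in range(seq_len):
    # same loop structure as forward(): the character cell consumes input_c_list[t]
    hx, cx = toy_char_rnn(torch.randn(1, hidden_dim), input_c_list[t], (hx, cx))
    if t == 0:                                     # pretend one word of length 3 is matched at t = 0
        ct = toy_word_rnn(torch.randn(1, 50), (hx, cx))
        input_c_list[t + 3 - 1].append(ct[0, :].unsqueeze(0))   # ct is read back at t = 2
```

In other words, the "LSTM output" that the articles describe being combined with `ct` is the character-level cell at the word's end position; the merge happens inside `self.rnn` on a later pass of the for loop, not in a separate network.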


ume-technology commented 3 years ago

This question came from an architecture diagram I had looked at earlier. That diagram gave me the wrong mental model of the data flow, which led to this issue. I later realized that the hidden information is passed back into the network through the for loop. Thank you very much for your guidance.
