Closed: ume-technology closed this issue 3 years ago.
If you search for `ct` directly in this code, you can see where it is used.
`ct` is stored in `input_c_list`, and `input_c_list` is then consumed by `(hx, cx) = self.rnn(input[t], input_c_list[t], (hx, cx))`.
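To see the whole flow in one place, here is a minimal sketch of that loop, not the repo's exact code: `char_cell` stands in for `self.rnn` (the character-level multi-input cell) and `word_cell` for `self.word_rnn`; the function name and the `char_inputs` / `matched_words` arguments are illustrative, while `input_c_list` and `ct` match the names in the repo.

```python
def lattice_forward(char_inputs, matched_words, char_cell, word_cell, hx, cx):
    """Simplified sketch of the LatticeLSTM forward loop.

    char_inputs:   list of per-character embeddings, length seq_len
    matched_words: for each position t, a (word_embs, word_lens) pair for lexicon
                   words that START at t, or None if there are no matches
    char_cell:     character-level cell (stands in for self.rnn)
    word_cell:     word-level cell (stands in for self.word_rnn), returns cell states only
    """
    seq_len = len(char_inputs)
    # one slot per character position; each slot collects the cell states of all
    # matched words that END at that position
    input_c_list = [[] for _ in range(seq_len)]
    hidden_out, memory_out = [], []

    for t in range(seq_len):
        # the character cell consumes the word cell states collected for position t
        hx, cx = char_cell(char_inputs[t], input_c_list[t], (hx, cx))
        hidden_out.append(hx)
        memory_out.append(cx)

        if matched_words[t]:
            word_embs, word_lens = matched_words[t]
            # the word cell produces one cell state ct per matched word
            ct = word_cell(word_embs, (hx, cx))
            for idx, length in enumerate(word_lens):
                # store ct so the character cell at the word's END position uses it
                input_c_list[t + length - 1].append(ct[idx, :].unsqueeze(0))

    return hidden_out, memory_out
```

The key point is that `ct` is never returned as an output by itself: it is appended to `input_c_list` at the matched word's end position, and the character cell at that position merges it with the character input on a later iteration of the for loop.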
This question came from an architecture diagram I had looked at earlier; that diagram gave me a wrong impression of the data flow, which is why I opened this issue. I later realized that the hidden information is passed back into the network through the for loop. Thank you for your guidance.
In the LatticeLSTM class defined in latticelstm.py, in its forward() method:

```python
def forward(self, input, skip_input_list, hidden=None):
    """
        input: variable (batch, seq_len), batch = 1
        skip_input_list: [skip_input, volatile_flag]
        skip_input: three-dimensional list of length seq_len. Each element is a list
            of matched word ids and their lengths.
            Example: [[], [[25,13],[2,3]]] where 25/13 are word ids and 2/3 are word lengths.
        skip_input == gaz_list
    """
    volatile_flag = skip_input_list[1]
    skip_input = skip_input_list[0]  # gaz_list
    if not self.left2right:
        skip_input = convert_forward_gaz_to_backward(skip_input)
    input = input.transpose(1, 0)
    seq_len = input.size(0)  # character sequence length
    batch_size = input.size(1)
    assert (batch_size == 1)
    hidden_out = []
    memory_out = []
    if hidden:
        (hx, cx) = hidden
    else:
        hx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
        cx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
        if self.gpu:
            hx = hx.cuda()
            cx = cx.cuda()
```
After the following part of the model has finished computing:

```python
ct = self.word_rnn(word_emb, (hx, cx))
```

where is this `ct` data used? I really cannot figure it out. I have read many articles, and they all say it is concatenated with an LSTM output for further computation. But where is that LSTM introduced, and where in the code is this `ct` actually used? Please help, thanks!
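For context, a rough sketch of what a WordLSTMCell-style module computes is below. The class name and the single fused linear layer are simplifications, not the repo's exact implementation; the point it illustrates is that the word-level cell returns only a cell state `ct` (no output gate, no hidden state), which is why `ct` never appears as a separate LSTM output to concatenate with.

```python
import torch
import torch.nn as nn

class WordCellSketch(nn.Module):
    """Sketch of a word-level cell: given word embeddings and the character-level
    state (h, c) at the word's START position, it returns one cell state per word.
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # fused projection for input gate, forget gate, and candidate cell
        self.weight = nn.Linear(input_dim + hidden_dim, 3 * hidden_dim)

    def forward(self, word_emb, hc):
        h, c = hc                              # character-level state at word start
        h = h.expand(word_emb.size(0), -1)     # one row per matched word
        c = c.expand(word_emb.size(0), -1)
        gates = self.weight(torch.cat([word_emb, h], dim=1))
        i, f, g = gates.chunk(3, dim=1)
        i, f, g = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g)
        return f * c + i * g                   # ct: cell state only, no hidden output
```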