In the original solution code, the update of `H2` uses only `H1`, the output of the layer below, and never the previous value of `H2` itself. This does not match the RNN recurrence.
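For reference, in a standard stacked RNN each layer carries its own recurrent state. Using the weight names from the corrected code further below, the expected updates at time step $t$ are (this is the standard formulation, not quoted from the book):

$$
H_1^{(t)} = \operatorname{ReLU}\left(X^{(t)} W_{xh1} + H_1^{(t-1)} W_{hh1} + b_{h1}\right)
$$

$$
H_2^{(t)} = \operatorname{ReLU}\left(H_1^{(t)} W_{xh2} + H_2^{(t-1)} W_{hh2} + b_{h2}\right)
$$

The original code drops the $H_2^{(t-1)} W_{hh2}$ term entirely: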
```python
def rnn(inputs, state, params):
    # Shape of inputs: (num_steps, batch_size, vocab_size)
    W_xh1, W_hh1, b_h1, W_hh2, b_h2, W_hq, b_q = params
    H1, H2 = state
    outputs = []
    # Shape of X: (batch_size, vocab_size)
    for X in inputs:
        H1 = torch.relu(torch.mm(X, W_xh1) + torch.mm(H1, W_hh1) + b_h1)
        # Note: H2 depends only on the current H1; the previous H2 is never
        # used, so the second layer has no recurrence of its own.
        H2 = torch.relu(torch.mm(H1, W_hh2) + b_h2)
        Y = torch.mm(H2, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H1, H2)
```
I think the following implementation is more reasonable:
```python
def rnn(inputs, state, params):
    # Shape of inputs: (num_steps, batch_size, vocab_size)
    W_xh1, W_hh1, b_h1, W_xh2, W_hh2, b_h2, W_hq, b_q = params
    H1, H2 = state
    outputs = []
    # Shape of X: (batch_size, vocab_size)
    for X in inputs:
        # First RNN layer: current input plus this layer's previous hidden state
        H1 = torch.relu(torch.mm(X, W_xh1) + torch.mm(H1, W_hh1) + b_h1)
        # Second RNN layer: first layer's output plus this layer's previous hidden state
        H2 = torch.relu(torch.mm(H1, W_xh2) + torch.mm(H2, W_hh2) + b_h2)
        # Output layer
        Y = torch.mm(H2, W_hq) + b_q
        outputs.append(Y)
    # Concatenate the outputs of all time steps
    return torch.cat(outputs, dim=0), (H1, H2)
```
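As a sanity check, here is a minimal sketch of how this `rnn` could be driven end to end. The sizes and the `get_params` / `init_state` helpers below are illustrative assumptions, not the book's exact code:

```python
import torch

vocab_size, num_hiddens = 28, 256  # hypothetical sizes, for illustration only

def get_params(vocab_size, num_hiddens, device):
    # One weight per connection, matching the parameter order expected by rnn above
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    W_xh1 = normal((vocab_size, num_hiddens))
    W_hh1 = normal((num_hiddens, num_hiddens))
    b_h1 = torch.zeros(num_hiddens, device=device)
    W_xh2 = normal((num_hiddens, num_hiddens))
    W_hh2 = normal((num_hiddens, num_hiddens))
    b_h2 = torch.zeros(num_hiddens, device=device)
    W_hq = normal((num_hiddens, vocab_size))
    b_q = torch.zeros(vocab_size, device=device)
    return [W_xh1, W_hh1, b_h1, W_xh2, W_hh2, b_h2, W_hq, b_q]

def init_state(batch_size, num_hiddens, device):
    # Each layer keeps its own hidden state
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

device = torch.device('cpu')
params = get_params(vocab_size, num_hiddens, device)
state = init_state(batch_size=2, num_hiddens=num_hiddens, device=device)
X = torch.zeros((5, 2, vocab_size))  # (num_steps, batch_size, vocab_size)
Y, new_state = rnn(X, state, params)
print(Y.shape)  # torch.Size([10, 28]) -> (num_steps * batch_size, vocab_size)
```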
I'll fix it directly and submit a PR.