NELSONZHAO / zhihu

This repo contains the source code for my personal column (https://zhuanlan.zhihu.com/zhaoyeyu), implemented in Python 3.6. It includes Natural Language Processing and Computer Vision projects, such as text generation, machine translation, and deep convolutional GANs, along with other hands-on code.

anna_lstm #15

Closed: dulm closed this issue 6 years ago

dulm commented 6 years ago

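# TensorFlow 1.x API (tf.nn.rnn_cell)
import tensorflow as tf


def build_lstm(lstm_size, num_layers, batch_size, keep_prob):
    '''Build the LSTM layers.

    keep_prob: dropout keep probability
    lstm_size: number of units in each LSTM hidden layer
    num_layers: number of LSTM hidden layers
    batch_size: batch size
    '''
    # Build a basic LSTM cell
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)

    # Add dropout
    drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)

    # Stack: note that this reuses the same `drop` object for every layer
    cell = tf.nn.rnn_cell.MultiRNNCell([drop for _ in range(num_layers)])
    initial_state = cell.zero_state(batch_size, tf.float32)

    return cell, initial_state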
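The root cause is plain Python semantics: the list comprehension evaluates the same name `drop` on every iteration, so the list handed to MultiRNNCell holds `num_layers` references to one object rather than `num_layers` separate cells. A minimal sketch of that aliasing, using a hypothetical stand-in `Cell` class instead of TensorFlow:

```python
class Cell:
    """Hypothetical stand-in for an RNN cell object (not a TensorFlow class)."""
    pass


def stack_shared(num_layers):
    # Mirrors the original code: one cell is built, then referenced num_layers times
    drop = Cell()
    return [drop for _ in range(num_layers)]


layers = stack_shared(3)
# Every entry is the very same object, so any weights it owns are shared too
print(all(layer is layers[0] for layer in layers))  # True
print(len({id(layer) for layer in layers}))  # 1
```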

Change it to:

def build_lstm(lstm_size, num_layers, batch_size, keep_prob):
    '''Build the LSTM layers.

    keep_prob: dropout keep probability
    lstm_size: number of units in each LSTM hidden layer
    num_layers: number of LSTM hidden layers
    batch_size: batch size
    '''
    cell_list = []
    for _ in range(num_layers):
        # Build a new basic LSTM cell for each layer
        lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
        # Add dropout
        drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)
        cell_list.append(drop)

    # Stack the distinct cells
    cell = tf.nn.rnn_cell.MultiRNNCell(cell_list)
    initial_state = cell.zero_state(batch_size, tf.float32)

    return cell, initial_state
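The fix works because the cell is constructed inside the loop, so each iteration produces a new object. The same idea is often written as a small factory function called from a list comprehension; in TF 1.x code that would look like `tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(num_layers)])`, where `make_cell` is a hypothetical helper that builds and wraps one cell. A pure-Python sketch of the pattern, with `Cell` and `make_cell` as illustrative names only:

```python
class Cell:
    """Hypothetical stand-in for a BasicLSTMCell wrapped in DropoutWrapper."""

    def __init__(self, units):
        self.units = units


def make_cell(units):
    # Hypothetical factory: returns a brand-new cell on every call
    return Cell(units)


def stack_distinct(num_layers, units):
    # Calling the factory inside the comprehension builds one object per layer
    return [make_cell(units) for _ in range(num_layers)]


layers = stack_distinct(3, 512)
print(len({id(layer) for layer in layers}))  # 3
```

The difference from the original is only where the constructor call sits: inside the comprehension (or loop) it runs once per layer, outside it runs once in total.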

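Otherwise, in /tensorflow/python/ops/rnn_cell_impl.py, the shape of self._linear._weights inside BasicLSTMCell.call only fits the first cell and does not fit the second one. Because every layer is the same cell object, its weights are created once, sized for the first layer's input, and the second layer (whose input depth is lstm_size) cannot use them.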
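To make the mismatch concrete: a BasicLSTMCell with `num_units` units multiplies the concatenation of its input and previous hidden state by a kernel of shape `(input_depth + num_units, 4 * num_units)`. The first layer's input depth is the feature size of the inputs, while every later layer's input depth is `lstm_size`, so a single shared kernel cannot fit both unless those two numbers happen to be equal. The sketch below simulates this build-once behavior in plain Python; `LazyCell`, `run_stack`, and the sizes are hypothetical and not taken from the repo.

```python
class LazyCell:
    """Hypothetical cell that, like BasicLSTMCell, creates its kernel on first use."""

    def __init__(self, num_units):
        self.num_units = num_units
        self.kernel_shape = None  # not built yet

    def __call__(self, input_depth):
        # The kernel multiplies [inputs, h_prev], hence input_depth + num_units rows
        needed = (input_depth + self.num_units, 4 * self.num_units)
        if self.kernel_shape is None:
            self.kernel_shape = needed  # built once, sized for the first caller
        elif self.kernel_shape != needed:
            raise ValueError(
                f"kernel has shape {self.kernel_shape}, but this layer needs {needed}"
            )
        return self.num_units  # output depth, fed to the next layer as its input


def run_stack(cells, input_depth):
    # Feed each layer's output depth into the next layer, as MultiRNNCell does
    depth = input_depth
    for cell in cells:
        depth = cell(depth)
    return depth


lstm_size, num_layers, input_depth = 512, 2, 83  # hypothetical sizes

shared = LazyCell(lstm_size)
shared_error = None
try:
    run_stack([shared for _ in range(num_layers)], input_depth)
except ValueError as err:
    shared_error = str(err)
print("shared cell:", shared_error)
# shared cell: kernel has shape (595, 2048), but this layer needs (1024, 2048)

distinct = [LazyCell(lstm_size) for _ in range(num_layers)]
run_stack(distinct, input_depth)
print([c.kernel_shape for c in distinct])  # [(595, 2048), (1024, 2048)]
```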

NELSONZHAO commented 6 years ago

Thanks!