xuanyuansen / scalaLSTM

Using scala to implement tiny LSTM, mainly focusing on the BPTT process of training the network.
Apache License 2.0
20 stars 11 forks source link

深入理解LSTM的BPTT算法

LSTM网络结构

关于LSTM网络的结构可以阅读这篇文章:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

这里需要注意文章最后提及的LSTM两种变形,第一种是加入peephole,使得gate layer能够回溯前一个cell的状态,这增加了一些复杂度;第二种是GRU,将gate layer和forget layer合并为一个update layer,降低了复杂度。

LSTM网络的训练

LSTM的训练使用了BPTT算法,需要重要理解的一点是BPTT算法相当于BP算法扩展到序列(时序)数据,另一个需要理解的点是LSTM是recurrent neural network(这里注意理解recurrent neural network和recursive neural network的区别),BPTT算法在计算中要注意这一点。

LSTM的计算图Compute Graph

SCALA实现

利用SPARK实现minibatch方式的训练

几种常见的LSTM结构