nicodjimenez / lstm

Minimal, clean example of lstm neural network training in python, for learning purposes.

You forgot the tanh function in the last computation of bottom_data_is() #47

Open · mikechen66 opened this issue 4 years ago

mikechen66 commented 4 years ago

Issue: lstm.py, line 98.

There is a problem with this line of code: self.state.h = self.state.s * self.state.o. You forgot the tanh function. The formula is h_t = o_t * tanh(s_t), so the correct line of code is as follows.

self.state.h = np.tanh(self.state.s) * self.state.o
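
For a quick numeric sanity check of the difference (a minimal standalone sketch, not from the repository), take s = 2.0 and o = 0.5:

    import numpy as np

    s, o = 2.0, 0.5
    print(s * o)           # 1.0    -> current code: the unbounded cell state passes straight through
    print(np.tanh(s) * o)  # ~0.482 -> standard LSTM output, bounded in (-1, 1)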

The relevant part of the code is pasted below.

def bottom_data_is(self, x, s_prev=None, h_prev=None):
    # if this is the first lstm node in the network
    if s_prev is None: s_prev = np.zeros_like(self.state.s)
    if h_prev is None: h_prev = np.zeros_like(self.state.h)
    # save data for use in backprop
    self.s_prev = s_prev
    self.h_prev = h_prev

    # concatenate x(t) and h(t-1)
    xc = np.hstack((x,  h_prev))
    self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
    self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
    self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
    self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
    self.state.s = self.state.g * self.state.i + s_prev * self.state.f
    self.state.h = self.state.s * self.state.o  # line 98: missing np.tanh around self.state.s

try1995 commented 3 years ago

self.state.h = self.state.o * np.tanh(self.state.s)

cs-heibao commented 3 years ago

@mikechen66 There is also a problem in the backpropagation: it ignores the derivative of the tanh function.

    def top_diff_is(self, top_diff_h, top_diff_s):
        # notice that top_diff_s is carried along the constant error carousel
        ds = self.state.o * top_diff_h + top_diff_s
        do = self.state.s * top_diff_h
        di = self.state.g * ds
        dg = self.state.i * ds
        df = self.s_prev * ds
The corrected lines, with the tanh derivative included, would be:

        ds = self.state.o * (1 - np.tanh(self.state.s) ** 2) * top_diff_h + top_diff_s
        do = np.tanh(self.state.s) * top_diff_h
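
The corrected derivative can be verified with a finite-difference check (a minimal standalone sketch assuming h = o * tanh(s) with scalar values; not part of the repository):

    import numpy as np

    # finite-difference check of dh/ds for h = o * tanh(s)
    s, o, eps = 1.3, 0.7, 1e-6
    analytic = o * (1.0 - np.tanh(s) ** 2)
    numeric = (o * np.tanh(s + eps) - o * np.tanh(s - eps)) / (2 * eps)
    print(analytic, numeric)  # the two values agree to ~1e-9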
nicodjimenez commented 2 years ago

I think you're right that some/most implementations use the tanh, but that's not how I defined the forward pass in the blog article:

https://nicodjimenez.github.io/2014/08/08/lstm.html

If you want to make a PR to add that as an option, that's fine with me.
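
A sketch of what such an option could look like (the helper name output_from_state and the flag use_output_tanh are purely illustrative, not part of the current code):

    import numpy as np

    def output_from_state(s, o, use_output_tanh=False):
        # hypothetical helper: use_output_tanh=False reproduces the blog
        # article's definition h = s * o; use_output_tanh=True gives the
        # standard formulation h = tanh(s) * o
        return np.tanh(s) * o if use_output_tanh else s * o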

bot66 commented 2 years ago

yes