Smerity / sha-rnn

Single Headed Attention RNN - "Stop thinking with your head"

Randomly zeroing out hidden and memory during training #16

Open alisafaya opened 3 years ago

alisafaya commented 3 years ago

I've successfully reimplemented your work in the Julia / Knet DL framework here: SHA-RNN.jl. During training I've faced some problems with the first batch of the dataset. Since no previous hidden or attention memory exists, the model finds it hard to predict the right output, and during training the model sees this case only once per epoch. While looking for a way to deal with this issue, I saw this part in your main file:

            if hidden is not None:
                #if np.random.random() > 0.975:
                #    hidden = None
                #    #hidden = zero_hidden(hidden)
                hidden = repackage_hidden(hidden)
            if mems is not None:
                #if np.random.random() > 0.975:
                #    mems = None
                #    mems = zero_hidden(mems)
                mems = repackage_hidden(mems)

This seems like a proper solution to the problem, but you've commented it out. Why did you disable this part? Did this approach not help?

Thanks!

Smerity commented 3 years ago

Apologies for the delayed reply.

Brilliant work on the Julia / Knet implementation! I've looked towards Julia with interest and curiosity given the many advantages it offers. The more DL examples on that side of the fence the better! =]

Regarding the first batch problem, you are entirely correct. The SHA-RNN codebase, however, is optimized for final perplexity on enwik8 and similar documents, and hence rarely has to deal with "first batches". For the model to learn how to deal with them effectively generally means worse performance on long-form documents.

If you were interested in tailoring your model to handle such "first batches", you could indeed do what was in the codebase and zero out the hidden state. Better than that, however, would be to store an initial hidden state that's updated via gradients during model training. This doesn't make sense for the model I wrote, as there are only a few dozen examples of "first batches" per epoch.
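For reference, a minimal PyTorch sketch of this "learned initial hidden state" idea, assuming a plain `nn.LSTM` stand-in rather than the actual SHA-RNN model; the class and parameter names (`LearnedInitLSTM`, `h0`, `c0`) are illustrative, not part of the codebase:

    import torch
    import torch.nn as nn

    class LearnedInitLSTM(nn.Module):
        """Illustrative stand-in: the initial hidden/cell states are nn.Parameters,
        so gradients from any "first batch" update them during training."""
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
            # Shape (num_layers=1, batch=1, hidden_size); expanded per batch below
            self.h0 = nn.Parameter(torch.zeros(1, 1, hidden_size))
            self.c0 = nn.Parameter(torch.zeros(1, 1, hidden_size))

        def forward(self, x, hidden=None):
            if hidden is None:
                # No carried-over state: broadcast the learned initial state over the batch
                batch = x.size(0)
                hidden = (self.h0.expand(-1, batch, -1).contiguous(),
                          self.c0.expand(-1, batch, -1).contiguous())
            return self.rnn(x, hidden)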

The extreme version of this would be to consume part of the input, select between K initial hidden states, each tailored for a different category of input, and then run from there.
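A hypothetical sketch of that "K initial states" variant, again using a plain `nn.LSTM` stand-in; `k`, `prefix_len`, and the mean-pooled selector are assumptions made for illustration:

    import torch
    import torch.nn as nn

    class SelectableInitLSTM(nn.Module):
        """Keeps K learned initial states and soft-selects between them
        from a prefix of the input (a hard argmax would be the extreme version)."""
        def __init__(self, input_size, hidden_size, k=4, prefix_len=16):
            super().__init__()
            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
            self.prefix_len = prefix_len
            # K candidate initial states, each of shape (num_layers=1, hidden_size)
            self.h0 = nn.Parameter(torch.zeros(k, 1, hidden_size))
            self.c0 = nn.Parameter(torch.zeros(k, 1, hidden_size))
            # Tiny selector: mean-pool the prefix and score the K candidates
            self.selector = nn.Linear(input_size, k)

        def forward(self, x, hidden=None):
            if hidden is None:
                prefix = x[:, :self.prefix_len].mean(dim=1)             # (batch, input_size)
                weights = torch.softmax(self.selector(prefix), dim=-1)  # (batch, k)
                # Weighted mix of the candidates -> (num_layers, batch, hidden_size)
                h = torch.einsum('bk,klh->lbh', weights, self.h0)
                c = torch.einsum('bk,klh->lbh', weights, self.c0)
                hidden = (h.contiguous(), c.contiguous())
            return self.rnn(x, hidden)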

lunixbochs commented 3 years ago

> Better than that, however, would be to store an initial hidden state that's updated via gradients during model training.

I might try this, so to make sure I'm understanding correctly: should I do something like this?