salesforce / awd-lstm-lm

LSTM and QRNN Language Model Toolkit for PyTorch
BSD 3-Clause "New" or "Revised" License

- Update for pytorch 0.2. - Fix LockedDropout to broadcast correct axis. - Use relative path for default data source. #5

Closed jph00 closed 7 years ago

salesforce-cla[bot] commented 7 years ago

Thanks for the contribution! Before we can merge this, we need @racheltho to sign the Salesforce Contributor License Agreement.

Smerity commented 7 years ago

As a heads up, I've not forgotten about this ;)

The default PTB parameters as given in the README result in an untuned model validation perplexity of 65.2, which is far from the 61.2 that existed before the v0.2 update. My current theory is that it's a result of the fixed variational / locked dropout - either locked dropout doesn't work as well as expected (which may well be possible, as I've never seen it directly compared for non-recurrent connections, whereas for recurrent connections it has a massive advantage) or it requires substantially different hyperparameters.

I am currently testing what happens if you just run v0.2 with normal dropout. Once I've ascertained the source of the discrepancy I'll move forward with merging :)
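
For context, LockedDropout (variational dropout) samples a single dropout mask per batch element and feature, then reuses that same mask at every time step of a (seq, batch, feature) input - which is why broadcasting over the correct axis matters. A minimal sketch of that behaviour (close in spirit to the repo's LockedDropout, but not necessarily a verbatim copy):

import torch
import torch.nn as nn

class LockedDropout(nn.Module):
    # Variational ("locked") dropout: one mask per (batch, feature) position,
    # shared across all time steps of a (seq, batch, feature) tensor.
    def forward(self, x, dropout=0.5):
        if not self.training or not dropout:
            return x
        # Sample the mask for a single time step, scaled to preserve expectation...
        mask = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) / (1 - dropout)
        # ...and broadcast the same mask across the whole sequence.
        return mask.expand_as(x) * x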

jph00 commented 7 years ago

I'll be interested to hear what you find!

Smerity commented 7 years ago

Unfortunately the weight dropout update appears to be broken. After testing [fixed locked dropout, Smerity's silly locked dropout, normal dropout] and finding the same problematic convergence, I had a bug hunt.

import torch
from weight_drop import WeightDrop

wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)

# Input is (seq, batch, input)
x = torch.autograd.Variable(torch.randn(2, 1, 10))
h0 = None

print(wdrnn.module._all_weights)

run1 = [step.sum() for step in wdrnn(x, h0)[0].data]
run2 = [step.sum() for step in wdrnn(x, h0)[0].data]

print('Run 1:', run1)
print('Run 2:', run2)

# The first time step doesn't touch the (dropped) hidden-to-hidden weights, so it should match across runs
assert run1[0] == run2[0]
# The second time step does use them, so it should differ between runs
assert run1[1] != run2[1]

On PyTorch 0.2 we end up with:

Applying weight drop of 0.9 to weight_hh_l0
[['weight_ih_l0', 'weight_hh_l0_raw', 'bias_ih_l0', 'bias_hh_l0']]
Run 1: [-0.12040333496406674, -0.45819888915866613]
Run 2: [-0.12040333496406674, -0.45819888915866613]
Traceback (most recent call last):
  File "wdrop_test.py", line 21, in <module>
    assert run1[1] != run2[1]
AssertionError

while on PyTorch 0.1.12 we end up with:

Applying weight drop of 0.9 to weight_hh_l0
[['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']]
Run 1: [-0.003149353200569749, -0.43234687950462103]
Run 2: [-0.003149353200569749, -0.5475110579282045]

This is the expected behaviour: the first outputs should always match, as we're only performing weight drop on the hidden-to-hidden weight matrix, which hasn't been used yet at the first time step, but the second outputs should differ (unless we're very unlucky and draw the exact same dropout masks :P). I'll also replace the if __name__ == "__main__" mini test in weight_drop with this, as it'll be a more useful mini test than whatever in the world I wrote there previously ;)

Further pondering on weight drop may be necessary.

Smerity commented 7 years ago

I've got a fix that I believe works and am testing it now; if it holds up, I'll submit a pull request to your branch with just that fix :)

Related issue (and why this breakage is also similar for weight norm): https://github.com/pytorch/pytorch/issues/2515#issuecomment-327637901
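
For anyone hitting the same wall: the core DropConnect idea is to keep the "raw" weight as the real parameter and resample a dropout mask over it on every forward pass. The sketch below illustrates that on a plain linear layer, deliberately sidestepping the cuDNN weight-flattening interaction discussed in the linked issue; the class name WeightDropLinear is illustrative only, not the repo's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightDropLinear(nn.Module):
    # Illustrative DropConnect: the "raw" weight is the real Parameter and a
    # fresh dropout mask is applied to it on every forward pass.
    def __init__(self, in_features, out_features, dropout=0.5):
        super().__init__()
        self.weight_raw = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.dropout = dropout

    def forward(self, x):
        w = F.dropout(self.weight_raw, p=self.dropout, training=self.training)
        return F.linear(x, w, self.bias)

layer = WeightDropLinear(10, 10, dropout=0.9)
x = torch.randn(1, 10)
# Two calls draw two different masks, so the outputs should (almost surely) differ.
print(layer(x).sum().item(), layer(x).sum().item())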

jph00 commented 7 years ago

Yup I just figured out the same thing :) I started working on language modeling for the course today so this came up at just the right time!

Smerity commented 7 years ago

With weight drop working on PyTorch 0.2, python -u main.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.225 --seed 28 --epoch 500 gives 61.2 / 58.9 before finetuning, which approximately matches the previous results. As locked dropout is now different, the hyperparameters will likely need updating. I also want to retry using weight dropping to perform variational dropout.

I'll merge this and update the README (PyTorch 0.2 instructions, remove exact reproduction as that no longer holds, point to the PyTorch==0.1.12 release).

This also closes #3.