lmnt-com / haste

Haste: a fast, simple, and open RNN library
Apache License 2.0

README PyTorch example #9

Closed: nimz closed this issue 4 years ago

nimz commented 4 years ago

This is minor, but the three PyTorch layers defined in the README should be put on the GPU, e.g.

norm_lstm_layer = haste.LayerNormLSTM(input_size=128, hidden_size=256, zoneout=0.1, dropout=0.05).cuda()

since the input is a CUDA tensor.
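The underlying rule is that a module's parameters must live on the same device as its input. Below is a minimal sketch of a device-agnostic version of that pattern. It uses the built-in torch.nn.LSTM as a stand-in for haste.LayerNormLSTM (an assumption for illustration, since Haste's layers need a CUDA GPU and the haste_pytorch package), so it also runs on a CPU-only machine:

```python
import torch
import torch.nn as nn

# Pick the GPU when available; fall back to CPU so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for haste.LayerNormLSTM (assumption): torch.nn.LSTM with the same sizes.
# .to(device) moves the layer's parameters, like the .cuda() call in the fix.
lstm_layer = nn.LSTM(input_size=128, hidden_size=256).to(device)

# Input shaped [time, batch, features] (nn.LSTM's default layout), created on
# the same device as the layer's parameters.
x = torch.rand(25, 5, 128, device=device)

y, (h, c) = lstm_layer(x)
print(y.shape)  # torch.Size([25, 5, 256])
```

Calling .cuda() directly, as in the suggested fix, works whenever a GPU is present; the .to(device) form is one common way to keep the same script usable on machines without one.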

sharvil commented 4 years ago

Thanks so much for pointing this out, @nimz. Fixed!