Element-Research / rnn

Recurrent Neural Network library for Torch7's nn
BSD 3-Clause "New" or "Revised" License

Multi-function recurrent unit #344

Closed. juesato closed this 7 years ago.

juesato commented 7 years ago

Implementation of https://github.com/Element-Research/rnn/issues/269
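A rough usage sketch of the new unit, assuming a constructor of the form nn.MuFuRu(inputSize, outputSize) that mirrors the library's other recurrent modules (the final API may differ slightly):

```lua
-- Sketch only: constructor/usage assumed to mirror nn.GRU / nn.LSTM.
require 'rnn'

local inputSize, outputSize = 128, 256
local seqlen, batchsize = 10, 32

-- wrap the recurrent unit in a Sequencer, as with the other units in this library
local mufuru = nn.Sequencer(nn.MuFuRu(inputSize, outputSize))

-- input: seqlen x batchsize x inputSize (a table of per-step tensors also works)
local input = torch.randn(seqlen, batchsize, inputSize)
local output = mufuru:forward(input)
print(output:size())  -- seqlen x batchsize x outputSize
```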

A few notes:

I think this should be merged after https://github.com/torch/nn/pull/954, since performance is poor until then, but I wanted to get feedback on the approach before adding unit tests and documentation.

Let me know if you have any questions!

@nicholas-leonard

nicholas-leonard commented 7 years ago

@juesato @JoostvDoorn @jnhwkim can you review this? Minimum requirements: documentation (README.md), unit tests and code.

jnhwkim commented 7 years ago

@nicholas-leonard I'll need to read the paper first. I'll revisit this once I've caught up.

nicholas-leonard commented 7 years ago

@juesato I would really like to see this get merged. Any developments?

juesato commented 7 years ago

@nicholas-leonard Sorry, I left this hanging. I'll spend the next two hours working on this (speeding up the CMaxTable stuff, adding docs, and unit tests), and if I don't finish, I'll continue tomorrow.
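For anyone unfamiliar, nn.CMaxTable (the torch/nn module referenced above, presumably used here for the unit's element-wise max composition) takes a table of same-sized tensors and returns their element-wise maximum. A quick illustration (not the code in this PR):

```lua
-- nn.CMaxTable: element-wise maximum over a table of same-sized tensors.
require 'nn'

local m = nn.CMaxTable()
local a = torch.Tensor{1, 5, 2}
local b = torch.Tensor{4, 3, 6}

print(m:forward({a, b}))  -- 4  5  6

-- in backward, each position's gradient flows only to the tensor that held the max
local gradInput = m:backward({a, b}, torch.Tensor{1, 1, 1})
print(gradInput[1], gradInput[2])  -- {0, 1, 0} and {1, 0, 1}
```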

juesato commented 7 years ago

@nicholas-leonard I got some time to work on this over the weekend, but found a bug and need a bit more time. I'm hoping to finish up after work tomorrow, but it's fairly likely I won't have time, in which case I'll finish on Tuesday.

juesato commented 7 years ago

@nicholas-leonard I added unit tests and documentation, and I believe this should be ready to merge.

For future reference, I'm leaving the training curves on Penn Treebank (PTB) here.

Command (I added support for Adam, since that's what's used in the original paper):

CUDA_VISIBLE_DEVICES=0 th examples/recurrent-language-model.lua --progress --cuda --mfru --seqlen 20 --uniform 0.1 --hiddensize '{200}' --batchsize 32 --maxepoch 20 --device 1 --adam --startlr 0.01 --cutoff 5
Epoch #1 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m18s | Step: 4ms       
learning rate   0.009975025 
mean gradParam norm 8.4984952035468 
Speed : 0.004788 sec/batch  
Training PPL : 282.28724615675  
Validation PPL : 195.89575663903    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7  

Epoch #2 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate   0.00995005  
mean gradParam norm 9.6138208177836 
Speed : 0.004739 sec/batch  
Training PPL : 138.24728917969  
Validation PPL : 152.95716046988    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7  

Epoch #3 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate   0.009925075 
mean gradParam norm 10.585762301436 
Speed : 0.004736 sec/batch  
Training PPL : 99.841825199886  
Validation PPL : 138.14327248568    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7  

Epoch #4 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate   0.0099001   
mean gradParam norm 11.394943580532 
Speed : 0.004733 sec/batch  
Training PPL : 78.723720426491  
Validation PPL : 133.5123597645 
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7  

Epoch #5 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate   0.009875125 
mean gradParam norm 12.078977865072 
Speed : 0.004733 sec/batch  
Training PPL : 65.066691855094  
Validation PPL : 134.65883864519    

Epoch #6 :  
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate   0.00985015  
mean gradParam norm 12.725478478451 
Speed : 0.004731 sec/batch  
Training PPL : 55.300139163342  
Validation PPL : 139.33760207709    

As a baseline, if we swap out MuFuRU for a GRU, we get this curve:

Epoch #1 :  
 [======================================== 29048/29048 ================================>]  Tot: 1m36s | Step: 3ms       
learning rate   0.009975025 
mean gradParam norm 7.7288590271154 
Speed : 0.003351 sec/batch  
Training PPL : 252.14280286606  
Validation PPL : 178.4052220798 
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7  

Epoch #2 :  
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate   0.00995005  
mean gradParam norm 8.4460625165352 
Speed : 0.003365 sec/batch  
Training PPL : 127.50807904334  
Validation PPL : 141.2727270623 
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7  

Epoch #3 :  
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate   0.009925075 
mean gradParam norm 9.040807982285  
Speed : 0.003365 sec/batch  
Training PPL : 94.050081227786  
Validation PPL : 127.12406725727    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7  

Epoch #4 :  
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate   0.0099001   
mean gradParam norm 9.548776472761  
Speed : 0.003364 sec/batch  
Training PPL : 75.183904110218  
Validation PPL : 121.60937118905    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7  

Epoch #5 :  
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate   0.009875125 
mean gradParam norm 10.028425562097 
Speed : 0.003355 sec/batch  
Training PPL : 62.757740662728  
Validation PPL : 121.03703331266    
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7  

So the training loss looks similar across the two, but generalization seems better with the GRU here. I haven't done any hyperparameter search; I just thought this could be useful for users looking to sanity-check their results in the future.
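Regarding the Adam support mentioned above: a rough, self-contained sketch of one way the update can be wired in with the optim package (the toy model, gradient clipping, and exact wiring here are illustrative assumptions, not the code in this PR):

```lua
-- Sketch only: plausible Adam wiring via the optim package; details may differ
-- from the actual --adam implementation in the example script.
require 'nn'
require 'optim'

-- toy model and data standing in for the language model
local model = nn.Linear(10, 2)
local criterion = nn.MSECriterion()
local input, target = torch.randn(8, 10), torch.randn(8, 2)

local params, gradParams = model:getParameters()
local adamConfig = {learningRate = 0.01}  -- cf. --startlr 0.01 above
local adamState = {}

local function feval()
   gradParams:zero()
   local output = model:forward(input)
   local loss = criterion:forward(output, target)
   model:backward(input, criterion:backward(output, target))
   -- rough analogue of --cutoff 5: clip the gradient norm
   local norm = gradParams:norm()
   if norm > 5 then gradParams:mul(5 / norm) end
   return loss, gradParams
end

for step = 1, 100 do
   optim.adam(feval, params, adamConfig, adamState)
end
```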

nicholas-leonard commented 7 years ago

@juesato Thanks for following through on this to the end!