kaishengtai / torch-ntm

A Neural Turing Machine implementation in Torch.

model learned for the copy task with default parameters doesn't generalize #7

Open · eulerreich opened this issue 9 years ago

eulerreich commented 9 years ago

I modified the copy task to save the model every 100 iterations. After doing

th tasks/copy.lua

with the default parameters, I experimented with the resulting model, and it is clear that seq and torch.gt(forward(model, seq), .5) differ substantially for most choices of seq. I saved the loss per sequence in the names of the serialized model files, and they suggest that the good performance of the model at the end of training is a fluke of the manual random seed 0 hardcoded into copy.lua.
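
For reference, the checkpointing change was roughly the following (a minimal sketch; model, loss, and iter are placeholder names, not necessarily the variables copy.lua actually uses):

-- inside the training loop of tasks/copy.lua: every 100 iterations,
-- serialize the model with the current per-sequence loss in the filename
if iter % 100 == 0 then
  local fname = string.format('epoch%05d_%.4f.model', iter, loss)
  torch.save(fname, model)
end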

$> ls | tail
epoch09100_3.9718.model                                                                                                                        
epoch09200_4.5562.model                                                                                                                        
epoch09300_4.5259.model                                                                                                                        
epoch09400_4.4691.model                                                                                                                        
epoch09500_5.2147.model                                                                                                                        
epoch09600_3.6268.model                                                                                                                        
epoch09700_1.8399.model                                                                                                                        
epoch09800_1.2273.model                                                                                                                        
epoch09900_0.5149.model                                                                                                                        
epoch10000_0.2789.model 

Compare with

$> ls | head                                                                                                                                   
epoch00100_5.5219.model                                                                                                                        
epoch00200_5.6047.model                                                                                                                        
epoch00300_5.5736.model                                                                                                                        
epoch00400_5.6103.model                                                                                                                        
epoch00500_5.5607.model                                                                                                                        
epoch00600_5.5996.model                                                                                                                        
epoch00700_5.4609.model                                                                                                                        
epoch00800_5.7227.model                                                                                                                        
epoch00900_5.5619.model                                                                                                                        
epoch01000_5.5650.model 

In fact, the only epochs where the loss per sequence drops below 1 are the last two.

Did you get a different result with the default parameters?

kaishengtai commented 9 years ago

Are you seeing this when you run the trained model on new examples, or when training a model from scratch using a different random seed?

eulerreich commented 9 years ago

I trained the model with the default parameters (in particular, the random seed 0 hardcoded into the file). After training finished, I checked it in the th REPL with new random examples; for example,

th> seq = generate_sequence(5, 8)                                                                                                              
                                                                      [0.0002s]
th> output = forward(model, seq)                                                                                                               
                                                                      [0.3955s]
th> seq                                                                                                                                        
 0  0  0  1  1  1  0  0  0  0                                                                                                                  
 0  0  1  0  1  0  0  0  1  0                                                                                                                  
 0  0  0  0  0  0  1  0  1  0                                                                                                                  
 0  0  0  0  0  1  0  1  0  0                                                                                                                  
 0  0  1  0  1  1  0  0  0  0                                                                                                                  
[torch.DoubleTensor of size 5x10]                                                                                                              

                                                                      [0.0001s]
th> torch.gt(output, .5)                                                                                                                       
 0  0  0  1  0  1  0  0  1  1                                                                                                                  
 0  0  1  1  1  0  1  1  0  1                                                                                                                  
 0  0  0  0  0  0  0  0  0  1                                                                                                                  
 0  0  0  0  0  0  0  0  0  0                                                                                                                  
 0  0  1  0  1  0  0  0  0  0                                                                                                                  
[torch.ByteTensor of size 5x10]                                     
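
To quantify the mismatch, one can compare the thresholded output against the input bit by bit. This is a quick sketch, assuming the seq and output variables from the session above:

-- bit error between the input sequence and the thresholded prediction
-- (assumes the seq and output tensors from the session above)
local pred = torch.gt(output, 0.5):double()
local wrong = torch.ne(pred, seq):sum()
print(string.format('%d / %d bits differ', wrong, seq:nElement()))
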
kaishengtai commented 9 years ago

OK, I see. I'll have a look -- thanks for pointing out this issue.

eulerreich commented 8 years ago

I'm getting the same non-convergence for the associative recall task. After running for 50000 epochs, my log files indicate that the losses still fluctuate quite a bit; see also the filename-parsing sketch after the listings below.

$> ls results/1442574433/ | tail -n 20                                                                                                         
epoch48100_9.0221.model                                                                                                                        
epoch48200_9.0616.model                                                                                                                        
epoch48300_0.0393.model                                                                                                                        
epoch48400_8.7340.model                                                                                                                        
epoch48500_11.9884.model                                                                                                                       
epoch48600_10.3534.model                                                                                                                       
epoch48700_4.9097.model                                                                                                                        
epoch48800_10.3642.model                                                                                                                       
epoch48900_0.0683.model                                                                                                                        
epoch49000_11.1729.model                                                                                                                       
epoch49100_4.0965.model                                                                                                                        
epoch49200_2.9329.model                                                                                                                        
epoch49300_8.6038.model                                                                                                                        
epoch49400_0.0436.model                                                                                                                        
epoch49500_0.0823.model                                                                                                                        
epoch49600_8.0040.model                                                                                                                        
epoch49700_7.4330.model                                                                                                                        
epoch49800_10.0362.model                                                                                                                       
epoch49900_7.4087.model                                                                                                                        
epoch50000_10.0742.model               

Compare with

$> ls results/1442574433/ | head -n 20                                                                                                         
epoch00100_12.3335.model                                                                                                                       
epoch00200_12.5032.model                                                                                                                       
epoch00300_12.2715.model                                                                                                                       
epoch00400_12.5553.model                                                                                                                       
epoch00500_12.5074.model                                                                                                                       
epoch00600_12.5660.model                                                                                                                       
epoch00700_12.6916.model                                                                                                                       
epoch00800_12.4654.model                                                                                                                       
epoch00900_12.4325.model                                                                                                                       
epoch01000_12.6071.model                                                                                                                       
epoch01100_12.4807.model                                                                                                                       
epoch01200_12.2052.model                                                                                                                       
epoch01300_12.3652.model                                                                                                                       
epoch01400_12.5196.model                                                                                                                       
epoch01500_12.3711.model                                                                                                                       
epoch01600_12.7436.model                                                                                                                       
epoch01700_12.3393.model                                                                                                                       
epoch01800_12.4030.model                                                                                                                       
epoch01900_12.3585.model                                                                                                                       
epoch02000_12.3311.model        
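
One quick way to see how much the loss fluctuates over the whole run is to parse it back out of the checkpoint filenames. This is just a sketch, assuming the epochNNNNN_LOSS.model naming and results directory shown above:

require 'torch'
require 'paths'
-- collect the loss encoded in each checkpoint filename and print summary stats
local losses = {}
for f in paths.files('results/1442574433') do
  local l = f:match('^epoch%d+_([%d%.]+)%.model$')
  if l then table.insert(losses, tonumber(l)) end
end
local t = torch.Tensor(losses)
print(string.format('n=%d  mean=%.3f  std=%.3f  min=%.3f  max=%.3f',
  t:nElement(), t:mean(), t:std(), t:min(), t:max()))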

Here is an example output from one of the models saved at the end.

-- query 1 of 4: index 4                                                                                                                       
 0  0  1  0  1  0  1  0                                                                                                                        
 0  0  1  0  0  0  1  1                                                                                                                        
 0  0  1  1  1  1  0  0                                                                                                                        
[torch.DoubleTensor of size 3x8]  
-- target 1 of 4: index 5                                                                                                                      
 0  0  1  0  1  0  1  0                                                                                                                        
 0  0  0  0  1  0  1  1                                                                                                                        
 0  0  0  0  1  1  0  1                                                                                                                        
[torch.DoubleTensor of size 3x8]
-- output 1 of 4                                                                                                                               
 0.0001  0.0001  0.5167  0.4437  0.4385  0.4977  0.9151  0.0787                                                                                
 0.0000  0.0000  0.3638  0.2945  0.4635  0.0279  0.7452  0.7260                                                                                
 0.0000  0.0000  0.1876  0.7161  0.5930  0.3402  0.0610  0.1457                                                                                
[torch.DoubleTensor of size 3x8]