This is caused by a mistake: the RNN is initialized twice (once at the top, and again further down, after the optimizer is created), so the optimizer is optimizing parameters that aren't being used. Commit 0cc55f5aaed44e7903edb8842671411301fcf003 should fix it.
where the RNN is initialized twice (once at the top, once below creating the optimizer)
I don't understand this.
I checked commit 0cc55f5, and it seems to be the same as what I did above with the optimizer. The results are the same: only a few correct predictions, and most are wrong.
The problem is lower down, around https://github.com/spro/practical-pytorch/commit/0cc55f5aaed44e7903edb8842671411301fcf003#diff-e9a91a525ccafb52f5b1e35131d4011cL51
The RNN was being re-created after the optimizer:
```python
rnn = RNN(n_letters, n_hidden, n_categories)  # rnn 1
optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)  # uses rnn 1's parameters

def train():
    ...
    output = rnn(...)  # refers to rnn 2, because of the reassignment below
    optimizer.step()   # still updates rnn 1's parameters
    ...

rnn = RNN(n_letters, n_hidden, n_categories)  # rnn 2 causes the problem; delete this
```
So the optimizer was not working: `rnn` was redefined and got a completely new set of parameters, while the optimizer still held references to the old ones.
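The failure mode can be reproduced without PyTorch. In this minimal plain-Python sketch, `Param`, `TinyModel`, and `TinySGD` are hypothetical stand-ins for a parameter tensor, `RNN`, and `torch.optim.SGD`. Like the real optimizer, `TinySGD` stores references to the parameter objects that exist when it is constructed, so rebinding the global name afterwards means gradients are computed on the new model while updates go to the old one.

```python
# Plain-Python reproduction of the "model re-created after the optimizer" bug.
# Param, TinyModel and TinySGD are hypothetical stand-ins for a parameter
# tensor, RNN and torch.optim.SGD; they are not PyTorch classes.


class Param:
    """A scalar parameter with a gradient slot, like a tensor with .grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class TinyModel:
    def __init__(self):
        self.w = Param(0.0)

    def parameters(self):
        return [self.w]


class TinySGD:
    def __init__(self, params, lr):
        # Keep references to the parameter objects that exist right now.
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def train_step(target=1.0):
    # Like train() above, this looks up the global `model` at call time.
    w = model.w
    w.grad = 2 * (w.value - target)  # d/dw of (w - target) ** 2
    optimizer.step()
    return w.value


# Buggy order: model 1 -> optimizer -> model 2.
model = TinyModel()
optimizer = TinySGD(model.parameters(), lr=0.1)
model = TinyModel()  # re-created: the optimizer still points at model 1
for _ in range(50):
    buggy = train_step()

# Fixed order: create the model once, then the optimizer.
model = TinyModel()
optimizer = TinySGD(model.parameters(), lr=0.1)
for _ in range(50):
    fixed = train_step()

print(f"buggy w = {buggy:.4f}, fixed w = {fixed:.4f}")
# prints "buggy w = 0.0000, fixed w = 1.0000"
```

In the buggy run the model that computes gradients never moves, because every update lands on the orphaned first model. The fixed run matches the change suggested above: delete the second `RNN(...)` call so the model is created once, before the optimizer.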
Cool, thanks! The mistake in the official "Classifying Names with a Character-Level RNN" tutorial should be fixed as well:
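```python
import time
import math

n_epochs = 100000
print_every = 5000
plot_every = 1000

# Re-creates the RNN; an optimizer built earlier from a previous instance
# would not see these new parameters.
rnn = RNN(n_letters, n_hidden, n_categories)

# Keep track of losses for plotting
current_loss = 0
all_losses = []
```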
In the "char-rnn-classification" tutorial, the weights are updated by the following code.
I tried using `optimizer.step()` to update the weights instead. However, the results are very bad: most of the predictions are wrong. I'm new to PyTorch; can anyone explain the difference between these two methods?
Thanks
==================Results with optimizer.step()======================
==================Results using the method in the tutorial======================
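For plain SGD without momentum, a manual update and `optimizer.step()` apply the same rule, `p <- p - lr * grad`, so switching between them should not by itself change the results. The practical difference in this thread is when the parameter list is read: a manual loop over `rnn.parameters()` asks for the parameters on every call, so it follows whatever `rnn` currently refers to, while the optimizer fixed its list at construction. That is why re-creating the RNN broke only the `optimizer.step()` version. The sketch below is plain Python, not PyTorch: `Param`, `SimpleSGD`, and `manual_update` are hypothetical stand-ins, and the learning rate is arbitrary.

```python
# Plain-Python sketch: a manual SGD update and an optimizer-style step()
# apply the same rule, p <- p - lr * grad. Param, SimpleSGD and
# manual_update are hypothetical stand-ins, not PyTorch APIs.


class Param:
    """A parameter vector with a gradient, like a tensor with .grad."""

    def __init__(self, data, grad):
        self.data = list(data)
        self.grad = list(grad)


def manual_update(get_params, lr):
    # Tutorial style: ask the model for its parameters on every call.
    for p in get_params():
        p.data = [d - lr * g for d, g in zip(p.data, p.grad)]


class SimpleSGD:
    """Plain SGD (no momentum, no weight decay), optimizer style."""

    def __init__(self, params, lr):
        # The parameter list is fixed when the optimizer is constructed.
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.data = [d - self.lr * g for d, g in zip(p.data, p.grad)]


def make_params():
    return [Param([0.5, -1.0], [0.1, 0.2]), Param([2.0], [-0.4])]


lr = 0.005  # arbitrary learning rate for the example

manual = make_params()
manual_update(lambda: manual, lr)

stepped = make_params()
SimpleSGD(stepped, lr).step()

manual_values = [p.data for p in manual]
stepped_values = [p.data for p in stepped]
print(manual_values == stepped_values)  # True
```

In actual PyTorch both styles also need the gradients cleared every iteration (`rnn.zero_grad()` or `optimizer.zero_grad()`), because `backward()` accumulates into `.grad`.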