d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

The second training with use_random_iter=True in rnn-scratch does not reset the weights #1678

Closed by floriandonhauser 3 years ago

floriandonhauser commented 3 years ago

I am currently working on chapters 8 and 9 for the TensorFlow version (8.6 is already merged). I have found that the last code block in 8.5 continues the training from the previous code block instead of starting over. In my opinion, the second training run should start from freshly initialized weights.

[Image: rnn-scratch-training]

The new code to do so would be about 2 lines longer for each framework:

#@tab mxnet
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,
                      init_rnn_state, rnn)
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),
          use_random_iter=True)
#@tab pytorch
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,
                      init_rnn_state, rnn)
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),
          use_random_iter=True)
#@tab tensorflow
with strategy.scope():
    net = RNNModelScratch(len(vocab), num_hiddens, init_rnn_state, rnn, get_params)

train_ch8(net, train_iter, vocab_random_iter, num_hiddens, lr,
          num_epochs, strategy, use_random_iter=True)
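To illustrate why re-creating `net` matters, here is a minimal, framework-free sketch. `ToyModel` and `toy_train` are hypothetical stand-ins for `RNNModelScratch` and `train_ch8` (they are not d2l functions), and the scalar weight and quadratic loss are assumptions chosen only to keep the example short.

```python
import random


class ToyModel:
    """Hypothetical stand-in for RNNModelScratch: a single scalar weight."""

    def __init__(self, seed=0):
        # A fixed seed makes every fresh instance start from the same weight.
        rng = random.Random(seed)
        self.w = rng.uniform(-1.0, 1.0)


def toy_train(model, target=3.0, lr=0.1, num_epochs=20):
    """Hypothetical stand-in for train_ch8: gradient descent on (w - target)^2.

    Returns the loss recorded at the start of every epoch.
    """
    losses = []
    for _ in range(num_epochs):
        losses.append((model.w - target) ** 2)
        grad = 2.0 * (model.w - target)
        model.w -= lr * grad  # updates the model in place
    return losses


# Reusing the same object: the second run continues from the trained weight.
net = ToyModel()
first = toy_train(net)
continued = toy_train(net)

# Re-instantiating first: the second run starts from the initial weight again.
net = ToyModel()
restarted = toy_train(net)

print(f"first run, initial loss:      {first[0]:.4f}")
print(f"continued run, initial loss:  {continued[0]:.4f}")
print(f"restarted run, initial loss:  {restarted[0]:.4f}")
```

In this sketch the continued run begins with a much smaller loss than the first run, while the restarted run reproduces the first run's starting point. That is the behavior the extra `net = RNNModelScratch(...)` line is meant to give.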

@terrytangyuan Which version is correct and should be used for the TF version I am working on? Do we want to restart with use_random_iter=True or continue the training?

The whole TensorFlow training loop also has a problem: new models will not work with it if they have different parameters. The reason is that get_params is called too early, which only works if the model has already been called once before training. Otherwise, the params of the model have not yet been initialized! I am working on a fix for this as well and will most likely create a PR for 9.1 (GRU) and 9.2 (LSTM) with it tomorrow.
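The "parameters fetched too early" problem can be sketched without any framework. The example below assumes a lazily built layer whose weights only exist after its first call. `LazyLinear` and `collect_params` are hypothetical names for illustration and are not part of d2l or TensorFlow.

```python
class LazyLinear:
    """Toy layer whose weights are created on the first call (hypothetical)."""

    def __init__(self, num_outputs):
        self.num_outputs = num_outputs
        self.params = []  # stays empty until the input size is known

    def __call__(self, x):
        if not self.params:
            # Build one weight row per output, sized to match the input.
            self.params = [[0.1] * len(x) for _ in range(self.num_outputs)]
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.params]


def collect_params(model):
    """Hypothetical stand-in for fetching a model's trainable parameters."""
    return list(model.params)


layer = LazyLinear(num_outputs=2)
too_early = collect_params(layer)   # fetched before any call: nothing there
layer([1.0, 2.0, 3.0])              # the first call creates the weights
after_call = collect_params(layer)  # now the weights can be found

print(len(too_early), len(after_call))
```

Under these assumptions, a training loop that collects parameters once before the first forward pass would end up with an empty list, which matches the symptom described above.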

terrytangyuan commented 3 years ago

Let's stick with the behavior in MXNet/PyTorch implementations for now to be consistent.

floriandonhauser commented 3 years ago

I think consistency is important, but I suggest we change all three variants in the near future. Other people have also found this problem and have mentioned it in the discussion of this chapter: https://discuss.d2l.ai/t/implementation-of-recurrent-neural-networks-from-scratch/486/9

As soon as I am done with all RNN topics, we can come back to this issue and I will create a PR.

terrytangyuan commented 3 years ago

Okay, yeah, I'd suggest focusing on new chapters/sections first, and make sure to discuss with other maintainers too if the change involves all frameworks.

astonzhang commented 3 years ago

@floriandonhauser Thanks! Can you send a separate PR to only fix "resetting the params" in the second run? We'll merge this fix first to ensure correctness.

floriandonhauser commented 3 years ago

@astonzhang Yes, I will first merge my RNN changes with all 3 versions being consistent. Afterwards, I will make the change for all 3 versions at once.

astonzhang commented 3 years ago

@floriandonhauser, would you like to send a PR to fix this issue now, or wait until more RNN PRs are merged? Thanks!

floriandonhauser commented 3 years ago

@astonzhang I will work on a PR for this issue now (probably finished tomorrow) and will then continue my work on the rest of the RNN chapters. I also noticed that train_ch8 has an additional parameter in the TensorFlow implementation (num_hiddens), which is unnecessary and unused. I will include this change as well.

For chapters 9.4 and 9.7, I might not be able to implement them on my own as quickly, since TensorFlow does not have parameters for bidirectional and num_hiddens like, e.g., PyTorch has. I am trying to find a different solution but am not finished yet, and I am unsure whether I can find a solution for everything on my own.

terrytangyuan commented 3 years ago

@floriandonhauser Thank you!

astonzhang commented 3 years ago

@floriandonhauser Understood. Thank you!

floriandonhauser commented 3 years ago

Fixed with PR #1693