kaituoxu / Listen-Attend-Spell

A PyTorch implementation of Listen, Attend and Spell (LAS), an End-to-End ASR framework.

train from previous checkpoint #11

Open MNCTTY opened 5 years ago

MNCTTY commented 5 years ago

Hi

I tried to resume training from a previous checkpoint.

For example, I trained the model for 100 epochs and got the final.pth.tar file. I put its absolute path in run.sh, in these lines:

...
# logging and visualize
checkpoint=0
continue_from="/home/karina/Listen-Attend-Spell/egs/aishell/exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch100_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta/final.pth.tar"
print_freq=10
visdom=0
visdom_id="LAS Training"
...

but training exits with this log:

# train.py --train_json dump/train/deltatrue/data.json --valid_json dump/dev/deltatrue/data.json --dict data/lang_1char/train_chars.txt --einput 240 --ehidden 256 --elayer 3 --edropout 0.2 --ebidirectional 1 --etype lstm --atype dot --dembed 512 --dhidden 512 --dlayer 1 --epochs 10 --half_lr 1 --early_stop 0 --max_norm 5 --batch_size 64 --maxlen_in 800 --maxlen_out 150 --optimizer adam --lr 1e-3 --momentum 0 --l2 1e-5 --save_folder exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch10_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta --checkpoint 1 --continue_from /home/karina/Listen-Attend-Spell/egs/aishell/exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch100_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta/final.pth.tar --print_freq 10 --visdom 0 --visdom_id "LAS Training" 
# Started at Fri Sep 13 03:00:41 MSK 2019
#
Namespace(atype='dot', batch_size=64, checkpoint=1, continue_from='/home/karina/Listen-Attend-Spell/egs/aishell/exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch100_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta/final.pth.tar', dembed=512, dhidden=512, dict='data/lang_1char/train_chars.txt', dlayer=1, early_stop=0, ebidirectional=1, edropout=0.2, ehidden=256, einput=240, elayer=3, epochs=10, etype='lstm', half_lr=1, l2=1e-05, lr=0.001, max_norm=5.0, maxlen_in=800, maxlen_out=150, model_path='final.pth.tar', momentum=0.0, num_workers=4, optimizer='adam', print_freq=10, save_folder='exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch10_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta', train_json='dump/train/deltatrue/data.json', valid_json='dump/dev/deltatrue/data.json', visdom=0, visdom_id='LAS Training')
Seq2Seq(
  (encoder): Encoder(
    (rnn): LSTM(240, 256, num_layers=3, batch_first=True, dropout=0.2, bidirectional=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(38, 512)
    (rnn): ModuleList(
      (0): LSTMCell(1024, 512)
    )
    (attention): DotProductAttention()
    (mlp): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): Tanh()
      (2): Linear(in_features=512, out_features=38, bias=True)
    )
  )
)
Loading checkpoint model /home/karina/Listen-Attend-Spell/egs/aishell/exp/train_in240_hidden256_e3_lstm_drop0.2_dot_emb512_hidden512_d1_epoch100_norm5_bs64_mli800_mlo150_adam_lr1e-3_mmt0_l21e-5_delta/final.pth.tar
Traceback (most recent call last):
  File "/home/karina/Listen-Attend-Spell/egs/aishell/../../src/bin/train.py", line 146, in <module>
    main(args)
  File "/home/karina/Listen-Attend-Spell/egs/aishell/../../src/bin/train.py", line 139, in main
    solver = Solver(data, model, optimizier, args)
  File "/home/karina/Listen-Attend-Spell/src/solver/solver.py", line 43, in __init__
    self._reset()
  File "/home/karina/Listen-Attend-Spell/src/solver/solver.py", line 53, in _reset
    self.tr_loss[:self.start_epoch] = package['tr_loss'][:self.start_epoch]
RuntimeError: The expanded size of the tensor (10) must match the existing size (13) at non-singleton dimension 0.  Target sizes: [10].  Tensor sizes: [13]
# Accounting: time=4 threads=1
# Ended (code 1) at Fri Sep 13 03:00:45 MSK 2019, elapsed time 4 seconds

Which object causes this tensor size mismatch? Am I resuming from the checkpoint correctly?
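One possible reading of the traceback above: the run was launched with --epochs 10, so self.tr_loss presumably holds 10 entries, while the slice taken from the checkpoint holds 13, which would mean the checkpoint is already at epoch 13 or later. The sketch below is not the repository's code; restore_loss_history is a hypothetical helper, and it uses plain Python lists instead of tensors so it runs without PyTorch. It only illustrates the kind of guard that would turn the size mismatch into a readable error.

```python
def restore_loss_history(saved_losses, start_epoch, total_epochs):
    """Copy saved per-epoch losses into a fresh history of length total_epochs.

    Hypothetical helper: raises early when the checkpoint epoch is already
    past the requested number of epochs, instead of failing on a size mismatch.
    """
    if start_epoch > total_epochs:
        raise ValueError(
            "checkpoint is at epoch %d but --epochs is %d; "
            "pass an --epochs value larger than the checkpoint epoch"
            % (start_epoch, total_epochs))
    history = [0.0] * total_epochs
    # Never copy more entries than the checkpoint actually stores.
    n = min(start_epoch, len(saved_losses))
    history[:n] = saved_losses[:n]
    return history


if __name__ == "__main__":
    # Resuming at epoch 2 of 5 keeps the first two losses.
    print(restore_loss_history([0.5, 0.4, 0.3], 2, 5))
    # The situation suggested by the log: checkpoint at epoch 13, --epochs 10.
    try:
        restore_loss_history([0.1] * 13, 13, 10)
    except ValueError as err:
        print("refused:", err)
```

If this reading is right, passing an --epochs value larger than the epoch stored in the checkpoint should avoid the error.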

joepareti54 commented 3 years ago

Yes, checkpointing is a must-have for this type of training, especially if you plan to train on a single GPU. Is there any progress? My considerations are below:

Given how long the model takes to train, being able to checkpoint is essential. Is checkpointing supported? How can I make sure a long training run is checkpointed? Why is it not documented?

Here is what I see in the code:

in train.py : parser.add_argument('--checkpoint', dest='checkpoint', default=0, type=int, help='Enables checkpoint saving of model')

in solver.py :

        # Save model each epoch
        if self.checkpoint:
            file_path = os.path.join(
                self.save_folder, 'epoch%d.pth.tar' % (epoch + 1))
            torch.save(self.model.serialize(self.model, self.optimizer, epoch + 1,
                                            tr_loss=self.tr_loss,
                                            cv_loss=self.cv_loss),
                       file_path)
            print('Saving checkpoint model to %s' % file_path)
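
Going by the snippet above, running with --checkpoint 1 should leave files named epochN.pth.tar in the save folder. A minimal sketch, assuming that naming pattern, of picking the most recent one to pass to --continue_from; latest_checkpoint is a hypothetical helper, not part of the repository:

```python
import os
import re

# Matches the 'epoch%d.pth.tar' names written by the solver snippet above.
_EPOCH_FILE = re.compile(r"^epoch(\d+)\.pth\.tar$")


def latest_checkpoint(save_folder):
    """Return the path of the highest-numbered epochN.pth.tar, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(save_folder):
        match = _EPOCH_FILE.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(save_folder, name)
    return best_path
```

The epoch numbers are compared as integers, so epoch10.pth.tar is preferred over epoch9.pth.tar; other files such as final.pth.tar are ignored.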
