fengxinjie / Transformer-OCR

MIT License

@delveintodetail Is there something wrong in the predict.py file? #9

Open: delveintodetail opened this issue 4 years ago

delveintodetail commented 4 years ago

@delveintodetail Is there something wrong in the predict.py file?

Originally posted by @li10141110 in https://github.com/fengxinjie/Transformer-OCR/issues/4#issuecomment-607083558

    for epoch in range(10000):
        model.train()
        run_epoch(train_dataloader, model,
                  SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        test_loss = run_epoch(val_dataloader, model,
                              SimpleLossCompute(model.generator, criterion, None))
        print("test_loss", test_loss)
        torch.save(model.state_dict(), 'checkpoint/%08d_%f.pth' % (epoch, test_loss))

The evaluation should not be different from training, but in this implementation, he uses the same method.

gussmith commented 4 years ago

@delveintodetail It's not clear what you are pointing out.

In the predict.py file there is a model.eval() call, and there is the same call in train.py:

    for epoch in range(10000):
        model.train()  # training mode: dropout active
        run_epoch(train_dataloader, model,
                  SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()   # evaluation mode: dropout disabled
        # opt=None: the validation pass does not update the weights
        test_loss = run_epoch(val_dataloader, model,
                              SimpleLossCompute(model.generator, criterion, None))
        print("test_loss", test_loss)
        torch.save(model.state_dict(), 'checkpoint/%08d_%f.pth' % (epoch, test_loss))

And you are stating, I guess, that model.eval() should be different (not "should not be different"?) in train.py and predict.py, but that here they are the same.

Why should they be different?
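For reference, in PyTorch `model.eval()` only switches layers such as dropout (and batch norm) from training behaviour to inference behaviour; it does not change what the decoder is fed. The sketch below is a hypothetical pure-Python stand-in (`ToyDropout` is not a PyTorch class) that mimics the inverted-dropout convention used by `nn.Dropout`: in training mode inputs are zeroed with probability `p` and survivors are scaled by `1 / (1 - p)`, while in eval mode the layer is the identity.

```python
import random


class ToyDropout:
    """Hypothetical stand-in for a dropout layer with a train/eval switch.

    Training mode: each input is zeroed with probability p and the
    survivors are scaled by 1 / (1 - p) (inverted dropout).
    Eval mode: the input passes through unchanged.
    """

    def __init__(self, p=0.5, seed=0):
        self.p = p
        self.training = True              # modules start in training mode
        self.rng = random.Random(seed)    # seeded so the sketch is repeatable

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, xs):
        if not self.training:
            return list(xs)               # eval mode: identity, deterministic
        keep = 1.0 - self.p
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]


layer = ToyDropout(p=0.5)
x = [1.0, 2.0, 3.0, 4.0]
print(layer.train()(x))   # some entries zeroed, the rest doubled
print(layer.eval()(x))    # [1.0, 2.0, 3.0, 4.0]
```

So calling `model.eval()` in both train.py and predict.py is expected; whether evaluation uses teacher forcing is decided by how the decoder inputs are built, not by this switch.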

Pay20Y commented 4 years ago

I guess @delveintodetail means teacher forcing. However, predict.py does not use teacher forcing, so I don't think there are bugs in predict.py.
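To illustrate the distinction: under teacher forcing the decoder sees the ground-truth prefix at every step (as the validation loss in train.py appears to, since it reuses `run_epoch`), while at prediction time it is fed its own previous outputs. The sketch below is not the repository's code; `NEXT`, `predict_next`, `teacher_forced_predictions` and `greedy_decode` are hypothetical, and the "model" is a toy lookup table that has learned one wrong transition (`c -> o` instead of `c -> a`).

```python
BOS, EOS = "<s>", "</s>"

# Hypothetical toy "decoder": predicts the next token from the previous one only.
# It has learned one wrong transition: "c" -> "o" (should be "c" -> "a").
NEXT = {BOS: "c", "c": "o", "o": "w", "w": EOS, "a": "t", "t": EOS}


def predict_next(prefix):
    """Return the toy model's next-token prediction for a decoder prefix."""
    return NEXT.get(prefix[-1], EOS)


def teacher_forced_predictions(target):
    """At step t the decoder is fed the ground-truth prefix target[:t]."""
    return [predict_next(target[:t]) for t in range(1, len(target))]


def greedy_decode(max_len=10):
    """At each step the decoder is fed its own previous outputs."""
    out = [BOS]
    while len(out) < max_len and out[-1] != EOS:
        out.append(predict_next(out))
    return out


def token_accuracy(preds, gold):
    return sum(p == g for p, g in zip(preds, gold)) / len(gold)


target = [BOS, "c", "a", "t", EOS]
tf = teacher_forced_predictions(target)
print(tf)                              # ['c', 'o', 't', '</s>']
print(token_accuracy(tf, target[1:]))  # 0.75
print(greedy_decode())                 # ['<s>', 'c', 'o', 'w', '</s>']
```

With teacher forcing the single mistake costs one token (75% token accuracy), because the next step is fed the correct "a" again; with greedy decoding the same mistake propagates and the output "cow" is entirely wrong. This is why a teacher-forced validation loss can look more optimistic than what predict.py actually produces.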