fengxinjie / Transformer-OCR

MIT License

test loss would lead to model update in eval mode #7

Open li10141110 opened 4 years ago

li10141110 commented 4 years ago

Is this a bug? (screenshot attached)

gussmith commented 4 years ago

@li10141110 In test_loss, self.opt is set to None:

        model.eval()
        test_loss = run_epoch(val_dataloader, model, 
              SimpleLossCompute(model.generator, criterion, None))

Thus the if self.opt is not None: block is not executed. What is the issue then?

Sanster commented 4 years ago

In test_loss, loss.backward() is still performed, which means gradients are calculated and accumulated. Then in train_loss, the gradients accumulated during test_loss are applied by opt.step().
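
A minimal sketch of the failure mode described above (the model, optimizer, and data are illustrative stand-ins, not code from this repo):

import torch

# Toy model and optimizer, just to show the mechanism.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# "Validation" pass with the original SimpleLossCompute behaviour:
# loss.backward() runs even though no optimizer is used, so gradients
# from the validation data are left sitting in the .grad buffers.
val_x, val_y = torch.randn(8, 4), torch.randn(8, 2)
torch.nn.functional.mse_loss(model(val_x), val_y).backward()

# Next training step: backward() adds to the existing .grad buffers,
# so opt.step() applies mixed train + validation gradients.
train_x, train_y = torch.randn(8, 4), torch.randn(8, 2)
torch.nn.functional.mse_loss(model(train_x), train_y).backward()
opt.step()        # this update also reflects the validation batch
opt.zero_grad()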

li10141110 commented 4 years ago

@Sanster Thanks for answering @gussmith's question 😄

gussmith commented 4 years ago

Wouldn't the gradient simply be recalculated before opt.step()? Or are you suggesting it is being accumulated and is therefore updating the weights? So should it instead read:

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 
                              y.contiguous().view(-1)) / norm
        #loss.backward()
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.data * norm
gussmith commented 4 years ago

This part of the code, as the authors indicate, is the same as in The Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html

gussmith commented 4 years ago

The original code calculates gradients during validation. I verified that this effectively trains the model on the validation data, which should not happen.

With the code below, the model improves considerably faster on the training set, but does not do as well on the validation set (another bug somewhere else?). One major advantage of these changes is that the code goes through the validation set twice as fast as before, because it no longer calls loss.backward() there.

The recommended way to run validation is:

model.eval()
with torch.no_grad():
    ...

https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
https://discuss.pytorch.org/t/two-small-questions-about-with-torch-no-grad/27571
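
A small sketch of the distinction those threads discuss (illustrative code, not from this repo): model.eval() only switches layer behaviour, while torch.no_grad() actually disables autograd.

import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# model.eval() only changes layer behaviour (dropout, batch norm, ...);
# autograd still builds the graph, so backward() would still work here.
model.eval()
print(model(x).requires_grad)      # True

# torch.no_grad() disables graph construction entirely: no gradients
# can be computed inside the block, and memory/time are saved.
with torch.no_grad():
    print(model(x).requires_grad)  # False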

And indeed, the code that runs without error is the following:

        model.eval()
        with torch.no_grad():
            test_loss = run_epoch(val_dataloader, model, 
                  SimpleLossCompute(model.generator, criterion, None))
            print("test_loss", test_loss)

and:

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 
                              y.contiguous().view(-1)) / norm
        #loss.backward()
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.data * norm
ppwwyyxx commented 4 years ago

> Or are you suggesting it is being accumulated and is therefore updating the weights?

loss.backward() does accumulate gradients, as is clearly stated in the PyTorch documentation. So it's very clear that this code trains the model on the test set.
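
A minimal check of that accumulation behaviour (illustrative snippet, not from this repo):

import torch

# backward() adds into .grad rather than overwriting it.
w = torch.tensor(1.0, requires_grad=True)
(w * 2).backward()
print(w.grad)   # tensor(2.)
(w * 2).backward()
print(w.grad)   # tensor(4.)  <- the second backward() accumulated onto the first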