d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

Incorrect Use of torch.no_grad() in fit_epoch Method in d2l/torch.py::Trainer::fit_epoch #2573

Open caydenwei opened 6 months ago

caydenwei commented 6 months ago

Hello,

I noticed a potential issue in the fit_epoch method in https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py, where loss.backward() is called within a torch.no_grad() block:

self.optim.zero_grad()
with torch.no_grad():
    loss.backward()
    ...

This usage likely prevents the calculation of gradients, as loss.backward() should not be inside a torch.no_grad() block. The correct approach would be:

self.optim.zero_grad()
loss.backward()
...
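Whether placing loss.backward() inside a torch.no_grad() block actually stops gradients from being computed can be checked with a small experiment. The sketch below uses a single toy scalar parameter (w and v are illustrative names, not d2l code). Case 1 records the forward pass normally and calls only backward() under torch.no_grad(). Case 2 runs the forward pass itself under torch.no_grad().

```python
import torch

# Case 1: forward pass recorded normally, only backward() runs under no_grad.
w = torch.tensor(3.0, requires_grad=True)
loss = w ** 2  # the graph for loss = w^2 is recorded here
with torch.no_grad():
    loss.backward()  # traverses the graph that was recorded above
grad_backward_in_no_grad = w.grad.item()  # d(w^2)/dw at w = 3 is 6
print("grad with backward() inside no_grad:", grad_backward_in_no_grad)

# Case 2: the forward pass itself runs under no_grad, so no graph is recorded.
v = torch.tensor(3.0, requires_grad=True)
with torch.no_grad():
    loss_no_graph = v ** 2
print("loss requires grad:", loss_no_graph.requires_grad)
try:
    loss_no_graph.backward()
    backward_failed = False
except RuntimeError:
    # There is no graph to traverse, so backward() raises
    backward_failed = True
print("backward() raised RuntimeError:", backward_failed)
```

The contrast comes from what torch.no_grad() controls: it stops operations executed inside the block from being recorded for autograd. Wrapping the forward pass therefore leaves no graph for backward() to traverse, while wrapping only backward() still lets it use the graph recorded earlier and accumulate into w.grad.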

Here is the original code:

    def fit_epoch(self):
        """Defined in :numref:`sec_linear_scratch`"""
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.model.training_step(self.prepare_batch(batch))
            self.optim.zero_grad()
            with torch.no_grad():
                loss.backward()
                if self.gradient_clip_val > 0:  # To be discussed later
                    self.clip_gradients(self.gradient_clip_val, self.model)
                self.optim.step()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for batch in self.val_dataloader:
            with torch.no_grad():
                self.model.validation_step(self.prepare_batch(batch))
            self.val_batch_idx += 1
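For comparison, here is a minimal sketch of how fit_epoch could be arranged with loss.backward() moved out of the torch.no_grad() block, as the issue proposes, while gradient clipping and the optimizer step stay inside it. TinyTrainer, the toy linear-regression data, and the use of torch.nn.utils.clip_grad_norm_ in place of d2l's clip_gradients are all assumptions for illustration, not the d2l implementation.

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)

class TinyTrainer:
    """Hypothetical stand-in for d2l's Trainer, reduced to fit_epoch."""

    def __init__(self, model, train_loader, val_loader=None, lr=0.1,
                 gradient_clip_val=0.0):
        self.model = model
        self.train_dataloader = train_loader
        self.val_dataloader = val_loader
        self.optim = torch.optim.SGD(model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        self.gradient_clip_val = gradient_clip_val
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        self.val_losses = []

    def fit_epoch(self):
        self.model.train()
        for X, y in self.train_dataloader:
            loss = self.loss_fn(self.model(X), y)
            self.optim.zero_grad()
            loss.backward()  # backward pass outside the no_grad block
            with torch.no_grad():  # only clipping and the update inside
                if self.gradient_clip_val > 0:
                    clip_grad_norm_(self.model.parameters(),
                                    self.gradient_clip_val)
                self.optim.step()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for X, y in self.val_dataloader:
            with torch.no_grad():  # evaluation only, no graph needed
                self.val_losses.append(self.loss_fn(self.model(X), y).item())
            self.val_batch_idx += 1

# Toy data: y = 2x + 1, split into 10 batches of 10 points each
X = torch.linspace(-1, 1, 100).reshape(-1, 1)
y = 2 * X + 1
batches = [(X[i:i + 10], y[i:i + 10]) for i in range(0, 100, 10)]

model = nn.Linear(1, 1)
trainer = TinyTrainer(model, batches, val_loader=batches, gradient_clip_val=1.0)
for _ in range(50):
    trainer.fit_epoch()
print("final validation loss:", trainer.val_losses[-1])
```

The only design choice this sketch is meant to illustrate is the placement of the with torch.no_grad(): block; everything else about the hypothetical TinyTrainer is simplified relative to d2l's Trainer.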
wzongyu commented 6 months ago

I think it should be

self.optim.zero_grad()
loss.backward()
with torch.no_grad():
    self.optim.step()
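To see whether the extra torch.no_grad() around self.optim.step() in this suggestion changes the update, one can take a single SGD step on two identical copies of a toy model, once with the wrapper and once without. The model shape, learning rate, and random batch are arbitrary assumptions for illustration. PyTorch's built-in torch.optim optimizers already perform their parameter updates without recording autograd history (unless constructed with differentiable=True), so the wrapper is expected to be redundant but harmless.

```python
import copy
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
model_a = nn.Linear(3, 1)
model_b = copy.deepcopy(model_a)  # identical starting weights
initial = [p.detach().clone() for p in model_a.parameters()]
X, y = torch.randn(8, 3), torch.randn(8, 1)

def one_step(model, wrap_step):
    """Take one SGD step, optionally wrapping optim.step() in no_grad."""
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = F.mse_loss(model(X), y)
    optim.zero_grad()
    loss.backward()
    if wrap_step:
        with torch.no_grad():
            optim.step()
    else:
        optim.step()

one_step(model_a, wrap_step=True)
one_step(model_b, wrap_step=False)

# Both copies should end up with exactly the same parameters ...
same = all(torch.equal(p, q)
           for p, q in zip(model_a.parameters(), model_b.parameters()))
# ... and the step should actually have moved them.
changed = any(not torch.equal(p0, p)
              for p0, p in zip(initial, model_a.parameters()))
print("parameters identical:", same, "| parameters updated:", changed)
```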
caydenwei commented 6 months ago

Apologies for not being clear earlier. I'm uncertain about the correctness of a specific part of the code found at https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py: the fit_epoch method quoted in my first comment, where loss.backward(), the gradient clipping, and self.optim.step() all run inside a single with torch.no_grad(): block.