delip / PyTorchNLPBook

Code and data accompanying Natural Language Processing with PyTorch published by O'Reilly Media https://amzn.to/3JUgR2L
Apache License 2.0

get_num_batches uses integer division in Chapter3:ReviewDataSet #5

Closed (seanv507 closed this 5 years ago)

seanv507 commented 5 years ago

Isn't this wrong if data_size is not a multiple of batch_size? Shouldn't it be:


    # assumes numpy is imported at module level as np
    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset

        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return int(np.ceil(len(self) / batch_size))
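For example (hypothetical numbers, not from the book): with 1,000 items and a batch size of 64, floor division undercounts whenever there is a remainder:

    import numpy as np

    data_size, batch_size = 1000, 64
    print(data_size // batch_size)               # 15 full batches; 40 items left over
    print(int(np.ceil(data_size / batch_size)))  # 16 batches, counting the partial one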
braingineer commented 5 years ago

Hi @seanv507!

The method is designed to round down and return the number of full-sized batches. The final batch won't be full; it will contain len(self) % batch_size examples instead. The drop_last argument to the DataLoader lets you specify whether you want this last partial batch or not. By default, we drop it:

from torch.utils.data import DataLoader

def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """
    A generator function which wraps the PyTorch DataLoader. It will
    ensure each tensor is on the right device location.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    # move every tensor in the batch dict onto the requested device
    for data_dict in dataloader:
        yield {name: tensor.to(device) for name, tensor in data_dict.items()}

I hope that clears it up!