mlverse / torchvision

R interface to torchvision
https://torchvision.mlverse.org

Adding a `.getbatch()` method to the `mnist_dataset` dataset generator improves performance markedly #106

Open gavril0 opened 6 months ago


The standard dataset generator for the MNIST dataset, `mnist_dataset`, does not include a `.getbatch()` method; as a result, fetching a batch is quite slow, at least on CPU.

```r
library(torch)
library(torchvision)

# dataset root directory
dir <- "./dataset"

# download the MNIST training set
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# dataloader
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter <- train_dl$.iter()
microbenchmark::microbenchmark(b <- train_iter$.next())
```

The timings are:

```
Unit: milliseconds
                    expr     min       lq     mean   median       uq     max neval
 b <- train_iter$.next() 45.5263 47.78455 52.74117 49.19825 52.95675 87.2219   100
```

As explained in the vignette, in the absence of a `.getbatch()` method the dataloader calls `.getitem()` once per index to assemble a batch.
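To make the difference concrete, here is a plain-R analogy (no torch involved, and not the dataloader's actual code). The array `imgs` and the function `get_item()` are hypothetical stand-ins for the dataset's `self$data` and `.getitem()`. Without `.getbatch()`, a batch of 128 images is built from 128 separate calls whose results must then be stacked; with `.getbatch()`, the same batch can come from a single vectorised subsetting call. Both paths produce the same array.

```r
set.seed(1)
# Hypothetical stand-in for self$data: 1000 "images" of 28 x 28 pixels
imgs <- array(runif(1000 * 28 * 28), dim = c(1000, 28, 28))
# Hypothetical stand-in for .getitem(): returns one 28 x 28 image
get_item <- function(i) imgs[i, , ]

index <- sample(1000, 128)

# Per-item path: 128 separate calls, then stack the 28 x 28 matrices
# (simplify2array gives 28 x 28 x 128; aperm moves the batch axis first)
items <- lapply(index, get_item)
batch_loop <- aperm(simplify2array(items), c(3, 1, 2))

# Batch path: one vectorised subsetting call, already 128 x 28 x 28
batch_vec <- imgs[index, , , drop = FALSE]

stopifnot(identical(batch_loop, batch_vec))
```

The per-item path pays the R function-call and stacking overhead once per image, which is consistent with the speed-up reported in this issue.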

Interestingly, it seems that the `.getitem()` method might be usable as a `.getbatch()` method without any change:

```
# mnist_dataset .getitem() method
> train_ds$.getitem
function (index) 
{
    img <- self$data[index, , ]
    target <- self$targets[index]
    if (!is.null(self$transform)) 
        img <- self$transform(img)
    if (!is.null(self$target_transform)) 
        target <- self$target_transform(target)
    list(x = img, y = target)
}
<environment: 0x000001527445be68>
```
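The reason `self$data[index, , ]` copes with a vector of indices is R's array subsetting: a single index drops the first dimension, while a vector of indices keeps it as a batch axis. The base-R check below uses a hypothetical zero-filled array `imgs` in place of `self$data`, assuming (as the subsetting call suggests) that it is an n x 28 x 28 array. Note that, with the default `drop = TRUE`, a batch of size one would also collapse to a 28 x 28 matrix.

```r
# Hypothetical stand-in for self$data: 1000 images of 28 x 28 pixels
imgs <- array(0L, dim = c(1000, 28, 28))

# A single index, as passed to .getitem(), drops the first dimension
dim(imgs[5, , ])                      # 28 28

# A vector of indices, as passed to .getbatch(), keeps it as the batch axis
dim(imgs[1:128, , ])                  # 128 28 28

# drop = FALSE keeps the batch axis even for a batch of size one
dim(imgs[5, , , drop = FALSE])        # 1 28 28
```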

It is easy to add a `.getbatch()` method to the existing `mnist_dataset` dataset generator:

```r
# create a new dataset generator that extends mnist_dataset
mnist_dataset2 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    self$.getitem(index)
  }
)
```

Let's measure the performance with this new dataset generator:

```r
# create a dataset with the new dataset generator
train_ds2 <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# create a dataloader with the new dataset
train_dl2 <- dataloader(train_ds2, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader
train_iter2 <- train_dl2$.iter()
microbenchmark::microbenchmark(train_iter2$.next())
```

```
Unit: milliseconds
                expr      min       lq     mean   median       uq     max neval
 train_iter2$.next() 3.995601 4.328151 5.430246 4.601451 4.965501 11.7692   100
```

The new dataloader is about 10 times faster (median 49.2 ms vs. 4.6 ms per batch)!

That said, the new dataloader apparently cannot be used in place of `train_dl` in this example, which uses luz to train the network:

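```r
library(luz)

# mnist_module and test_dl come from the luz example referred to above (not shown);
# passing train_dl2 instead of train_dl here triggers the error below
fitted <- mnist_module %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(train_dl, epochs = 1, valid_data = test_dl)
```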

Doing so yields the error:

```
expected input[1, 28, 128, 28] to have 1 channels, but got 28 channels instead
```

I don't have a PC with a GPU, so I could not test whether loading the data on the GPU gives a similar improvement. I also wonder why `.getbatch()` is not always implemented, since it seems an easy way to improve performance. Although I did not investigate the origin of the error, `luz::fit` should be able to accept a dataloader whose dataset has a `.getbatch()` method.
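One possible explanation for the error, offered as an assumption rather than a diagnosis: `transform_to_tensor` is designed for a single image and appears to treat a 3-D array as height x width x channels, moving the last axis to the front. Applied to a whole 128 x 28 x 28 batch, that would give 28 x 128 x 28, which matches the `[28, 128, 28]` part of the error message (the leading 1 looks like a batch axis added to an unbatched input). The base-R sketch below mimics that permutation with `aperm()` and, for comparison, builds the N x C x H x W layout (128 x 1 x 28 x 28) that a 2-D convolution expects for a batch. The array `batch` is a hypothetical stand-in for what `.getitem(index)` returns before the transform.

```r
# Hypothetical stand-in for one batch of raw images: 128 x 28 x 28
batch <- array(seq_len(128 * 28 * 28), dim = c(128, 28, 28))

# Assumed behaviour of transform_to_tensor on a 3-D array:
# treat it as H x W x C and move the last axis first (C x H x W)
permuted <- aperm(batch, c(3, 1, 2))
dim(permuted)   # 28 128 28, i.e. read as 28 channels of 128 x 28 pixels

# Layout a 2-D convolution expects for a batch: N x C x H x W,
# obtained by inserting a size-one channel axis after the batch axis
# (column-major order is unchanged, so no values move)
nchw <- array(batch, dim = c(dim(batch)[1], 1, dim(batch)[2:3]))
dim(nchw)       # 128 1 28 28
```

If this reading is right, a `.getbatch()` that skips the per-image transform and instead converts the whole batch itself (for example with `torch_tensor()` followed by `$unsqueeze(2)` to add the channel axis) would produce the expected shape; this is a suggestion, not something tested in this issue.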