mlverse / torch

R Interface to Torch
https://torch.mlverse.org

dataset subset introducing NAs by coercion #1116

Open dominicwhite opened 11 months ago

dominicwhite commented 11 months ago

When I use the dataset_subset() function, I later get the following error while training the model with the fit() function:

Warning: NAs introduced by coercion to integer range
Error in (function (self, target, weight, reduction, ignore_index, label_smoothing) : Evaluation error: missing replacement values are not allowed.

Reproducible example:

library(torch)
library(torchvision)
library(luz)

train_ds <- kmnist_dataset(
  "imagesk", 
  download = TRUE,
  transform = . %>%
    transform_to_tensor() %>%
    torch_flatten()
  )

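# keep the first 1000 KMNIST images; note that this reassigns train_ds
train_ds <- dataset_subset(train_ds, indices=1:1000)
# these indices refer to the 1000-item train_ds created on the line above
valid_ds <- dataset_subset(train_ds, indices=1001:1500)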

train_dl <- dataloader(
  train_ds, 
  batch_size = 32,
  shuffle = TRUE
  )

valid_dl <- dataloader(
  valid_ds,
  batch_size = 32
  )

net <- nn_module(
  "onelayer",
  initialize = function() {
    self$net <- nn_sequential(
      nn_linear(784,128),
      nn_relu(),
      nn_linear(128,10)
    )
  },
  forward = function(x) {
    self$net(x)
  }
)

model1 <- net %>%
  setup(
    loss = nn_cross_entropy_loss(), 
    optimizer = optim_adam, 
    metrics = list(
      luz_metric_accuracy()
    )
  )

fitted1 <- fit(
  model1,
  train_dl,
  epochs = 2,
  valid_data = valid_dl,
  verbose = TRUE
)

I have found three separate "solutions" that each seem to fix this issue and let the model train without that error:

train_ds <- dataset_subset(train_ds, indices=1001:2024)
valid_ds <- dataset_subset(train_ds, indices=1:512)

However, I'm not sure why the original code doesn't work. Why would switching the subset indices (my third solution) fix this?
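One possible contributor, assuming `dataset_subset()` maps position `i` of a subset to `indices[i]` of the dataset it wraps (the same way PyTorch's `Subset` works): since `train_ds` is reassigned before `valid_ds` is created, `indices=1001:1500` would refer to positions inside a 1000-item subset, not inside the full KMNIST dataset. The base-R sketch below models only that index arithmetic. The variable names are hypothetical and no torch calls are involved, so it does not prove that this is what the library actually does.

```r
# Hypothetical base-R model of nested subsets (no torch required).
# Assumption: a subset stores only the positions it selects from the
# dataset it wraps, so nesting two subsets composes their index vectors.

n_full   <- 60000L              # size of the KMNIST training split
full_idx <- seq_len(n_full)     # stands in for the full dataset

# Original code: valid is taken from the 1000-item train subset
train_idx <- full_idx[1:1000]
valid_idx <- train_idx[1001:1500]     # positions past the end give NA

# Third "solution": 1:512 fits inside the 1024-item train subset ...
train_idx3 <- full_idx[1001:2024]
valid_idx3 <- train_idx3[1:512]       # maps to full indices 1001:1512
# ... but every validation item is then also a training item
overlap3 <- intersect(train_idx3, valid_idx3)

# Taking both subsets from the full dataset avoids NAs and overlap
train_ok <- full_idx[1:1000]
valid_ok <- full_idx[1001:1500]

cat("NA validation indices (original):", sum(is.na(valid_idx)), "\n")
cat("overlap (third solution):", length(overlap3), "\n")
cat("overlap (both from full):", length(intersect(train_ok, valid_ok)), "\n")
```

Under this model the script reports 500 NA indices for the original code, 512 shared items for the third solution, and 0 when both subsets are drawn from the full dataset. If the model holds, swapping the indices only "works" because the validation set becomes a slice of the training set.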

dfalbel commented 9 months ago

Sorry @dominicwhite for taking so long to look at this issue. I tried running your reproducible example using the dev version of torch and luz and could not reproduce. I feel like this could be related to something like https://github.com/mlverse/torch/issues/961