mlverse / torch

R Interface to Torch
https://torch.mlverse.org

stop_iteration_error enumerating dataset #180

Closed skeydan closed 4 years ago

skeydan commented 4 years ago

I'll update this issue as soon as I've found a way to reliably reproduce it (haven't yet; it seems to occur with the current state of mnist-dcgan.R, though).

+   cat(sprintf("Epoch %d - Loss D: %3f Loss G: %3f\n", epoch, mean(lossd), mean(lossg)))
+ }
[===================================>]  0s Loss D: 1.05004232272195 Loss G: 1.38054530530197
Error:
Run `rlang::last_error()` to see where the error occurred.
10. stop(fallback)
9. signal_abort(cnd)
8. rlang::abort(glue::glue(...), class = "stop_iteration_error") at conditions.R#18
7. stop_iteration_error() at utils-data-sampler.R#131
6. self$.sampler_iter() at utils-data-dataloader.R#197
5. self$.next_index() at utils-data-dataloader.R#235
4. self$.next_data() at utils-data-dataloader.R#203
3. parent.env(x)$.iter$.next() at utils-data-enum.R#16
2. `[[.enum_env`(b, 1)
1. b[[1]]

skeydan commented 4 years ago

Haven't been able to repro for 2 days, closing (for now) :-)

jaredlander commented 4 years ago

Not sure if it's related, but I am getting the following error when using enumerate().

rlang::last_error()
<error/stop_iteration_error>
Backtrace:
 1. torch:::mod2(b[[1]])
 4. torch:::`[[.enum_env`(b, 1)
 5. parent.env(x)$.iter$.next()
 6. self$.next_data()
 7. self$.next_index()
 8. self$.sampler_iter()
 9. torch:::stop_iteration_error()
Run `rlang::last_trace()` to see the full context.

I didn't provide any context, but I can make a reprex if this is related and helpful.
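
If it helps in the meantime, the skeleton I'd flesh out looks roughly like this (stand-in data and shapes, CPU only; tensor_dataset() stands in for my real dataset):

library(torch)

# stand-in data: 100 samples, 8 features, integer targets in {1, 2}
ds <- tensor_dataset(
  x = torch_randn(100, 8),
  y = torch_randint(1, 3, 100, dtype = torch_long())
)
dl <- dataloader(ds, batch_size = 16, shuffle = TRUE)

# iterate the same way as in the failing code
for (b in enumerate(dl)) {
  print(b[[1]]$shape)
}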

skeydan commented 4 years ago

@jaredlander please do :-)

... I was never able to reproduce it, but I have the feeling something is going on ... :-)

jaredlander commented 4 years ago

After restarting R I can't recreate it either! But I won't stop trying.

dfalbel commented 4 years ago

reprex or didn't happen 😛

In theory this should only happen if for some reason the length of the dataloader is not correctly calculated. But the code for enumerate is tricky, and there might be some edge cases we haven't caught yet.
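
For context, here is a minimal sketch of where the condition comes from. It assumes the manual iterator helpers dataloader_make_iter() / dataloader_next() and that an exhausted iterator signals the stop_iteration_error condition class seen in the tracebacks above:

library(torch)

# 10 samples with batch_size 4: length(dl) reports ceiling(10 / 4) = 3 batches
ds <- tensor_dataset(torch_randn(10, 3))
dl <- dataloader(ds, batch_size = 4)

it <- dataloader_make_iter(dl)
for (i in seq_len(length(dl))) {
  b <- dataloader_next(it)  # three calls succeed
}

# one more call walks past the end and signals stop_iteration_error;
# enumerate() normally translates this into "stop looping", so a
# miscalculated length would make it issue this extra call itself
tryCatch(
  dataloader_next(it),
  stop_iteration_error = function(e) message("iterator exhausted")
)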

jaredlander commented 4 years ago

That was my theory. I've seen similar issues with TensorFlow giving out when the rows didn't align properly. I'll post if I run into that problem again.

skeydan commented 4 years ago

Not quite a reprex, size-wise, but at least some indication of the circumstances in which it appears:

library(torch)
library(torchvision)
library(dplyr)

train_transforms <- function(img) {
  img %>%
    transform_random_resized_crop(size = c(224, 224)) %>%
    transform_color_jitter() %>%
    transform_random_horizontal_flip() %>%
    transform_to_tensor() %>%
    transform_normalize(mean = c(0.485, 0.456, 0.406), std = c(0.229, 0.224, 0.225))
}

valid_transforms <- function(img) {
  img %>%
    transform_resize(256) %>%
    transform_center_crop(224) %>%
    transform_to_tensor() %>%
    transform_normalize(mean = c(0.485, 0.456, 0.406), std = c(0.229, 0.224, 0.225))
}

test_transforms <- valid_transforms

target_transform <- function(x) {
  x <- torch_tensor(x, dtype = torch_long())
  x$squeeze(1)
}

# https://www.kaggle.com/gpiosenka/100-bird-species/data
data_dir <- "data/bird_species"

train_ds <- image_folder_dataset(file.path(data_dir, "train"),
                                 transform = train_transforms,
                                 target_transform = target_transform)
valid_ds <- image_folder_dataset(file.path(data_dir, "valid"),
                                 transform = valid_transforms,
                                 target_transform = target_transform)
test_ds <-
  image_folder_dataset(file.path(data_dir, "test"),
                       transform = test_transforms,
                       target_transform = target_transform)

batch_size <- 32
train_dl <- dataloader(train_ds, batch_size = batch_size, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = batch_size)
test_dl <- dataloader(test_ds, batch_size = batch_size)

model <- model_resnet18(pretrained = TRUE)
model$parameters %>% purrr::walk(function(param) param$requires_grad <- FALSE)

num_features <- model$fc$in_features

class_names <- train_ds$classes  # class names come from the folder structure

model$fc <- nn_linear(in_features = num_features, out_features = length(class_names))

device <- if (cuda_is_available()) torch_device("cuda:0") else "cpu"

model <- model$to(device = device)

criterion <- nn_cross_entropy_loss()

optimizer <- optim_sgd(model$parameters, lr = 0.001, momentum = 0.9)

num_epochs <- 10

scheduler <- optimizer %>%
  lr_one_cycle(max_lr = 0.05, epochs = num_epochs, steps_per_epoch = length(train_dl))

for (epoch in 1:num_epochs) {

  model$train()
  train_losses <- c()

  for (b in enumerate(train_dl)) {
    optimizer$zero_grad()
    output <- model(b[[1]]$to(device = "cuda"))
    loss <- criterion(output, b[[2]]$to(device = "cuda"))
    loss$backward()
    optimizer$step()
    scheduler$step()
    train_losses <- c(train_losses, loss$item())
    #print(optimizer$param_groups[[1]]$lr)
  }

  model$eval()
  valid_losses <- c()

  for (b in enumerate(valid_dl)) {
    output <- model(b[[1]]$to(device = "cuda"))
    loss <- criterion(output, b[[2]]$to(device = "cuda"))
    valid_losses <- c(valid_losses, loss$item())
  }

  cat(sprintf("Loss at epoch %d: training: %3f, validation: %3f\n", epoch, mean(train_losses), mean(valid_losses)))
}

Error:
Run `rlang::last_error()` to see where the error occurred.
16. stop(fallback)
15. signal_abort(cnd)
14. rlang::abort(glue::glue(..., .envir = env), class = "stop_iteration_error") at conditions.R#22
13. stop_iteration_error() at utils-data-sampler.R#131
12. self$.sampler_iter() at utils-data-dataloader.R#197
11. self$.next_index() at utils-data-dataloader.R#235
10. self$.next_data() at utils-data-dataloader.R#203
9. parent.env(x)$.iter$.next() at utils-data-enum.R#16
8. `[[.enum_env`(b, 1)
7. b[[1]]
6. mget(x = c("input", "weight", "bias", "stride", "padding", "dilation", "groups")) at gen-namespace.R#5112
5. torch_conv2d(input = input, weight = weight, bias = bias, stride = stride, padding = padding, dilation = dilation, groups = groups) at nnf-conv.R#47
4. nnf_conv2d(input, weight, self$bias, self$stride, self$padding, self$dilation, self$groups) at nn-conv.R#327
3. self$conv_forward_(input, self$weight) at nn-conv.R#332
2. self$conv1(x) at models-resnet.R#223
1. model(b[[1]]$to(device = "cuda"))

I'll leave this open until we chat, later today :-)

dfalbel commented 4 years ago

@skeydan I tried running your example, but it completed successfully. Still not able to reliably reproduce this...

Loss at epoch 1: training: 1.285877, validation: 2.441062
Loss at epoch 2: training: 1.043385, validation: 2.609288
Loss at epoch 3: training: 0.912953, validation: 2.584944
Loss at epoch 4: training: 0.732427, validation: 2.524903
Loss at epoch 5: training: 0.632346, validation: 2.596953
Loss at epoch 6: training: 0.553114, validation: 2.507043
Loss at epoch 7: training: 0.470457, validation: 2.515819
Loss at epoch 8: training: 0.387394, validation: 2.439595
Loss at epoch 9: training: 0.343123, validation: 2.351169
Loss at epoch 10: training: 0.302062, validation: 2.358391

skeydan commented 4 years ago

Did you run the code exactly as pasted above, or the version that was open in my IDE? I pasted this before we talked, and later we guessed it might be due to the iterator having been called before.
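
The pattern we had in mind was something like the following (purely illustrative; whether it actually triggers the error is exactly what I couldn't pin down):

# poke at a single batch interactively...
it <- dataloader_make_iter(train_dl)
b <- dataloader_next(it)

# ...then re-run the training loop in the same session; the suspicion
# was that leftover iterator state might confuse enumerate()'s bookkeeping
for (b in enumerate(train_dl)) {
  # training step
}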

If you look at https://github.com/mlverse/torchbook/blob/master/scripts/bird_species.R, that's close to what I must have executed that day ...

dfalbel commented 4 years ago

I used the code in the book repository. But I made some changes to torch's enum code yesterday, so maybe this is fixed now?
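
To try it out before the next CRAN release, installing the development version should pick up the change (standard remotes usage, nothing specific to this fix):

# install the development version of torch from GitHub
remotes::install_github("mlverse/torch")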

ozt-ca commented 4 years ago

It looks like #266 may be resolved by the latest update (as far as I can tell on my local env). Thanks a lot!

skeydan commented 4 years ago

Looks like it works fine now, thanks :-)