Closed: skeydan closed this issue 4 years ago
Haven't been able to repro for 2 days, closing (for now) :-)
Not sure if it's related, but I am getting the following error when using enumerate().
rlang::last_error()
<error/stop_iteration_error>
Backtrace:
1. torch:::mod2(b[[1]])
4. torch:::`[[.enum_env`(b, 1)
5. parent.env(x)$.iter$.next()
6. self$.next_data()
7. self$.next_index()
8. self$.sampler_iter()
9. torch:::stop_iteration_error()
Run `rlang::last_trace()` to see the full context.
I didn't provide any context, but I can make a reprex if this is related and helpful.
@jaredlander please do :-)
... I was never able to reproduce it, but I have the feeling there is something going on ... :-)
After restarting R I can't recreate it either! But I won't stop trying.
reprex or didn't happen 😛
In theory, this should only happen if for some reason the length of the dataloader is not calculated correctly. But the code for enumerate is tricky, and there might be some edge cases we haven't caught yet.
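To illustrate what "calculated correctly" means here, a minimal sketch with a tiny in-memory tensor_dataset (purely illustrative, not taken from the report above): the number of batches yielded by enumerate() should match the dataloader's reported length, with only the last batch allowed to be smaller.

library(torch)

# ten rows of dummy data, batch size 4 -> the dataloader should report
# ceiling(10 / 4) = 3 batches
ds <- tensor_dataset(x = torch_randn(10, 3), y = torch_randn(10, 1))
dl <- dataloader(ds, batch_size = 4)

dl$.length()

# iterating should yield exactly that many batches (sizes 4, 4, 2); if the
# reported length ever exceeded what the sampler can actually deliver, the
# last access would end in stop_iteration_error like the traceback above
for (b in enumerate(dl)) {
  print(dim(b[[1]]))
}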
That was my theory. I've seen similar issues with TensorFlow erroring out when the rows didn't align properly. I'll post if I run into that problem again.
Not quite a reprex, size-wise, but at least some indication of its appearance:
library(torch)
library(torchvision)
library(dplyr)
train_transforms <- function(img) {
  img %>%
    transform_random_resized_crop(size = c(224, 224)) %>%
    transform_color_jitter() %>%
    transform_random_horizontal_flip() %>%
    transform_to_tensor() %>%
    transform_normalize(mean = c(0.485, 0.456, 0.406), std = c(0.229, 0.224, 0.225))
}

valid_transforms <- function(img) {
  img %>%
    transform_resize(256) %>%
    transform_center_crop(224) %>%
    transform_to_tensor() %>%
    transform_normalize(mean = c(0.485, 0.456, 0.406), std = c(0.229, 0.224, 0.225))
}

test_transforms <- valid_transforms

target_transform <- function(x) {
  x <- torch_tensor(x, dtype = torch_long())
  x$squeeze(1)
}
# https://www.kaggle.com/gpiosenka/100-bird-species/data
data_dir <- 'data/bird_species'

train_ds <- image_folder_dataset(file.path(data_dir, "train"),
                                 transform = train_transforms,
                                 target_transform = target_transform)
valid_ds <- image_folder_dataset(file.path(data_dir, "valid"),
                                 transform = valid_transforms,
                                 target_transform = target_transform)
test_ds <- image_folder_dataset(file.path(data_dir, "test"),
                                transform = test_transforms,
                                target_transform = target_transform)
batch_size <- 32
train_dl <- dataloader(train_ds, batch_size = batch_size, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = batch_size)
test_dl <- dataloader(test_ds, batch_size = batch_size)
model <- model_resnet18(pretrained = TRUE)
model$parameters %>% purrr::walk(function(param) param$requires_grad <- FALSE)
num_features <- model$fc$in_features
class_names <- train_ds$classes  # class labels discovered by image_folder_dataset
model$fc <- nn_linear(in_features = num_features, out_features = length(class_names))
device <- if (cuda_is_available()) torch_device("cuda:0") else "cpu"
model <- model$to(device = device)
criterion <- nn_cross_entropy_loss()
optimizer <- optim_sgd(model$parameters, lr = 0.001, momentum = 0.9)
num_epochs <- 10
scheduler <- optimizer %>%
  lr_one_cycle(max_lr = 0.05, epochs = num_epochs, steps_per_epoch = train_dl$.length())

for (epoch in 1:num_epochs) {

  model$train()
  train_losses <- c()

  for (b in enumerate(train_dl)) {
    optimizer$zero_grad()
    output <- model(b[[1]]$to(device = "cuda"))
    loss <- criterion(output, b[[2]]$to(device = "cuda"))
    loss$backward()
    optimizer$step()
    scheduler$step()
    train_losses <- c(train_losses, loss$item())
    # print(optimizer$param_groups[[1]]$lr)
  }

  model$eval()
  valid_losses <- c()

  for (b in enumerate(valid_dl)) {
    output <- model(b[[1]]$to(device = "cuda"))  # inputs need to be on the same device as the model
    loss <- criterion(output, b[[2]]$to(device = "cuda"))
    valid_losses <- c(valid_losses, loss$item())
  }

  cat(sprintf("Loss at epoch %d: training: %3f, validation: %3f\n", epoch, mean(train_losses), mean(valid_losses)))
}
Error:
Run `rlang::last_error()` to see where the error occurred.
16. stop(fallback)
15. signal_abort(cnd)
14. rlang::abort(glue::glue(..., .envir = env), class = "stop_iteration_error") at conditions.R#22
13. stop_iteration_error() at utils-data-sampler.R#131
12. self$.sampler_iter() at utils-data-dataloader.R#197
11. self$.next_index() at utils-data-dataloader.R#235
10. self$.next_data() at utils-data-dataloader.R#203
9. parent.env(x)$.iter$.next() at utils-data-enum.R#16
8. `[[.enum_env`(b, 1)
7. b[[1]]
6. mget(x = c("input", "weight", "bias", "stride", "padding", "dilation", "groups")) at gen-namespace.R#5112
5. torch_conv2d(input = input, weight = weight, bias = bias, stride = stride, padding = padding, dilation = dilation, groups = groups) at nnf-conv.R#47
4. nnf_conv2d(input, weight, self$bias, self$stride, self$padding, self$dilation, self$groups) at nn-conv.R#327
3. self$conv_forward_(input, self$weight) at nn-conv.R#332
2. self$conv1(x) at models-resnet.R#223
1. model(b[[1]]$to(device = "cuda"))
I'll leave this open until we chat, later today :-)
@skeydan I tried running your example, but it ran successfully. Still not able to reliably reproduce this...
Loss at epoch 1: training: 1.285877, validation: 2.441062
Loss at epoch 2: training: 1.043385, validation: 2.609288
Loss at epoch 3: training: 0.912953, validation: 2.584944
Loss at epoch 4: training: 0.732427, validation: 2.524903
Loss at epoch 5: training: 0.632346, validation: 2.596953
Loss at epoch 6: training: 0.553114, validation: 2.507043
Loss at epoch 7: training: 0.470457, validation: 2.515819
Loss at epoch 8: training: 0.387394, validation: 2.439595
Loss at epoch 9: training: 0.343123, validation: 2.351169
Loss at epoch 10: training: 0.302062, validation: 2.358391
Did you run the code exactly as pasted above, or the version that was open in my IDE? I pasted that before we talked, and later we guessed it might be due to the iterator having been called before.
If you look at https://github.com/mlverse/torchbook/blob/master/scripts/bird_species.R, that's close to what I must have executed that day ...
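For what it's worth, a sketch of how that guess could play out, inferred only from the `[[.enum_env` -> `.iter$.next()` frames in the traceback (i.e., assuming the elements produced by enumerate() share one underlying iterator that advances on each access):

# hypothetical interactive sequence, not a verified repro
batches <- enumerate(train_dl)

# inspecting a batch in the IDE before running the loop would already advance
# the shared iterator ...
b <- batches[[1]]
dim(b[[1]])

# ... so a later full pass over `batches` would ask for one batch more than
# the sampler has left and hit stop_iteration_error on the final element
for (b in batches) {
  output <- model(b[[1]]$to(device = "cuda"))
}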
I used the code in the book repository. But I made some changes to torch's enum code yesterday, so maybe that fixed it?
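To test against those changes before a release, one option is installing the development version from GitHub (assuming the usual build prerequisites are in place):

# install the in-development torch from GitHub
remotes::install_github("mlverse/torch")
# if the backend libraries need to be (re)downloaded afterwards:
torch::install_torch()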
It looks like #266 may be resolved by the latest update (as far as I can tell from trying it in my local environment). Thanks a lot!
Looks like it works fine now, thanks :-)
I'll update this issue as soon as I've found a way to reliably reproduce it (haven't yet, though it seems to occur with the current state of mnist-dcgan.R).