mlverse / luz

Higher Level API for torch
https://mlverse.github.io/luz/

add mixup callback and mixup cross entropy loss #82

Closed: skeydan closed this 2 years ago

skeydan commented 3 years ago

Hey @dfalbel, I think this went pretty well! A few comments/questions:

1)

In documenting the callback, I now have

#' Implementation of https://arxiv.org/abs/1710.09412.
#' As of today, tested only for categorical data,
#' where targets are expected to be integers, not one-hot encoded vectors.

It's a lot more generic now and could work for other data as well, but until we've actually tested that, it probably makes sense to warn users. Do you agree?
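For reference, here is the construction from the paper cited in that doc comment. The notation is ours: e_c is the one-hot vector for class c, p_i is the predicted class distribution for the mixed input, and partners are paired through a random permutation π of the batch, as is common in implementations. It shows why integer targets are enough for cross-entropy. The loss is linear in the target distribution, so the mixed one-hot target never has to be built explicitly.

```latex
\begin{aligned}
\lambda &\sim \mathrm{Beta}(\alpha, \alpha), \qquad j = \pi(i) \ \text{for a random permutation } \pi \text{ of the batch},\\
\tilde{x}_i &= \lambda\, x_i + (1 - \lambda)\, x_j, \qquad
\tilde{y}_i = \lambda\, e_{y_i} + (1 - \lambda)\, e_{y_j},\\
\mathrm{CE}(p_i, \tilde{y}_i) &= -\sum_{k} \tilde{y}_{ik} \log p_{ik}
  = \lambda\, \mathrm{CE}(p_i, e_{y_i}) + (1 - \lambda)\, \mathrm{CE}(p_i, e_{y_j}).
\end{aligned}
```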

2)

In the loss module, instead of this state-modifying thing

self$loss$reduction <- "none"

I tried cloning the loss (module$clone()) and keeping two separate instance fields instead, but that yielded an error ("attempt to apply non-function"). What's your opinion on this? Do you find the above ugly?
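Why does the wrapped loss need unreduced (per-sample) values at all? If there is one mixing weight per sample (an assumption here), the weighted combination has to be formed before averaging. The minimal base-R illustration below uses made-up per-sample losses. The numbers are purely illustrative and do not come from the PR:

```r
# Illustrative per-sample losses of the same predictions against the two targets
l1 <- c(0.2, 2.0)   # loss w.r.t. the original targets y1
l2 <- c(1.5, 0.1)   # loss w.r.t. the shuffled targets y2
w  <- c(0.9, 0.1)   # assumed per-sample mixing weights

# Weight each sample's losses first, then reduce (requires reduction = "none")
per_sample <- mean(w * l1 + (1 - w) * l2)

# Reducing first and then mixing with an averaged weight gives a different number
reduced_first <- mean(w) * mean(l1) + (1 - mean(w)) * mean(l2)

per_sample     # 0.31
reduced_first  # 0.95
```

With a single scalar weight per batch, the two results would coincide. So switching reduction to "none" matters only when weights vary per sample.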

3)

When I ran the tests, testthat reported the callback tests as "skipped" (because they contain no actual expectations). That could be misleading to others (they might think the tests contain no code at all), so I wrapped everything in expect_silent(). Does that make sense, or do you have another suggestion?

Thanks!

dfalbel commented 3 years ago

I just made a commit that slightly improves the docs and makes a small change for 2. For 3, ideally there would be a more precise assertion that checks the callback is actually used. How do the fastai folks test this?

skeydan commented 3 years ago

Thanks @dfalbel! Just corrected a typo :-)

As for testing, like we talked about, it's mostly visual testing ;-): see https://github.com/fastai/fastai/blob/master/nbs/19_callback.mixup.ipynb

skeydan commented 3 years ago

Hey @dfalbel, I refactored this a bit and extracted the main mixup functionality into its own function, so it can be tested standalone (and its effects can easily be inspected).
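The sketch below, in base R, shows what a standalone, inspectable mixup step along these lines could look like. The name mixup_sketch() and its arguments are hypothetical; this is not the signature of luz's nnf_mixup(). The returned target (both label vectors plus the weights) mirrors the list-of-lists target discussed in this PR:

```r
# Hypothetical base-R analogue of a standalone mixup step (not luz's nnf_mixup()).
# x: numeric matrix, one row per sample; y: integer class labels;
# weight: one mixing weight per sample; shuffle: permutation pairing each sample with a partner.
mixup_sketch <- function(x, y, weight, shuffle = sample.int(nrow(x))) {
  stopifnot(nrow(x) == length(y), length(weight) == length(y))
  x2 <- x[shuffle, , drop = FALSE]
  y2 <- y[shuffle]
  # weight has length nrow(x), so it recycles down the columns: row i is scaled by weight[i]
  mixed <- weight * x + (1 - weight) * x2
  # Targets stay integer labels: both label vectors plus the weights
  list(x = mixed, y = list(ys = list(y, y2), weight = weight))
}

set.seed(1)
x <- matrix(1:6, nrow = 3)          # 3 samples, 2 features
y <- c(1L, 2L, 3L)
w <- rbeta(3, 0.4, 0.4)             # alpha = 0.4 chosen arbitrarily for illustration
out <- mixup_sketch(x, y, w, shuffle = c(3L, 1L, 2L))
out$x                               # inspect the mixed inputs directly
```

The function has no dependency on the callback context, so properties such as "a weight of 1 returns the original inputs" can be asserted directly. There is no need to wrap a whole fit() in expect_silent().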

skeydan commented 3 years ago

Renamed to nnf_mixup(), like we said.

skeydan commented 2 years ago

Hm... I was hoping this might be ready to merge, but it's not ...

The test cases don't compute accuracy, but they definitely should, because there is a dependency between the metric and the callback. Adding luz_metric_accuracy() to the test makes fit() fail:

dl <- get_categorical_dl(x_size = 768)

model <- get_model()
expect_silent({
  mod <- model %>%
    setup(
      loss = nn_mixup_loss(torch::nn_cross_entropy_loss(ignore_index = 222)),
      optimizer = torch::optim_adam,
      metrics = list(luz_metric_accuracy())
    ) %>%
    set_hparams(input_size = 768, output_size = 10) %>%
    fit(dl, verbose = FALSE, epochs = 2, valid_data = dl,
        callbacks = list(luz_callback_mixup()))
})
Error in cpp_torch_tensor(data, rev(dimension), options, requires_grad, : 
R type not handled
28.
stop(structure(list(message = "R type not handled", call = cpp_torch_tensor(data, 
rev(dimension), options, requires_grad, inherits(data, "integer64")), 
cppstack = structure(list(file = "", line = -1L, stack = c("/home/key/R/x86_64-redhat-linux-gnu-library/4.1/torch/libs/torchpkg.so(+0x2c8aa9) [0x7f0a5a4d2aa9]", 
"/home/key/R/x86_64-redhat-linux-gnu-library/4.1/torch/libs/torchpkg.so(+0x2953ce) [0x7f0a5a49f3ce]", ... at RcppExports.R#10617
27.
cpp_torch_tensor(data, rev(dimension), options, requires_grad, 
inherits(data, "integer64")) at tensor.R#39
26.
methods$initialize(...) at R7.R#31
25.
Tensor$new(data, dtype, device, requires_grad, pin_memory) at tensor.R#307
24.
torch_tensor(e2, device = e1$device) at operators.R#124
23.
`==.torch_tensor`(pred, target) at metrics.R#113
22.
x$update(ctx$pred, ctx$target) at callbacks.R#244
21.
FUN(X[[i]], ...)
20.
lapply(ctx$metrics$train, function(x) x$update(ctx$pred, ctx$target)) at callbacks.R#242
19.
self[[callback_nm]]() at callbacks.R#14
18.
callback$call(name) at callbacks.R#23
17.
FUN(X[[i]], ...)
16.
lapply(callbacks, function(callback) {
callback$call(name)
}) at callbacks.R#22
15.
force(code)
14.
(withr::with_(set = function() {
cpp_autograd_set_grad_mode(FALSE)
}, reset = function(old) {
cpp_autograd_set_grad_mode(current_mode) ... at autograd.R#64
13.
torch::with_no_grad({
lapply(callbacks, function(callback) {
callback$call(name)
}) ... at callbacks.R#21
12.
call_all_callbacks(self$callbacks, name) at context.R#158
11.
ctx$call_callbacks("on_train_batch_end") at module.R#228
10.
eval_bare(loop, env)
9.
coro::loop(for (batch in ctx$data) {
ctx$batch <- batch
ctx$iter <- ctx$iter + 1L
ctx$call_callbacks("on_train_batch_begin") ... at module.R#221
8.
doTryCatch(return(expr), name, parentenv, handler)
7.
tryCatchOne(expr, names, parentenv, handlers[[1L]])
6.
tryCatchList(expr, classes, parentenv, handlers)
5.
tryCatch(.expr, interrupt = function (err) 
{
ctx$call_callbacks("on_interrupt")
invisible(NULL) ...
4.
rlang::with_handlers(!!!ctx$handlers, .expr = {
for (epoch in seq_len(ctx$max_epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L ... at module.R#209
3.
fit.luz_module_generator(., dl, verbose = FALSE, epochs = 2, 
valid_data = dl, callbacks = list(luz_callback_mixup()))
2.
fit(., dl, verbose = FALSE, epochs = 2, valid_data = dl, callbacks = list(luz_callback_mixup()))
1.
model %>% setup(loss = nn_mixup_loss(torch::nn_cross_entropy_loss(ignore_index = 222)), 
optimizer = torch::optim_adam, metrics = list(luz_metric_accuracy())) %>% 
set_hparams(input_size = 768, output_size = 10) %>% fit(dl, 
verbose = FALSE, epochs = 2, valid_data = dl, callbacks = list(luz_callback_mixup()))

This makes sense: just like the loss, the metric should work differently for the training and test sets. Do you have any advice here? Off the top of my head, everything I can think of is rather ugly.

dfalbel commented 2 years ago

How should accuracy be computed in that case? Should we use the transformed x with the true y? If so, instead of returning a list() with y1, y2 and weight as the target, maybe we could add that information as attributes to ctx$batch$y, i.e. something like this:

y2 <- self$ctx$batch$y[shuffle]
attr(ctx$batch$y, "y2") <- y2
attr(ctx$batch$y, "weight") <- weight

And then modify nn_mixup_loss to retrieve them.

skeydan commented 2 years ago

No, the idea is (and nn_mixup_loss() does this) to compute the loss using both targets, weighted by the mixing weights. This is why we have the list of lists, comprising both targets and the weights.
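The base-R sketch below illustrates that idea. All names are hypothetical, and it is not the luz implementation. The loss receives the two label vectors and the per-sample weights, computes unreduced cross-entropy against each, and combines them before averaging. When the target is an ordinary label vector (e.g. a validation batch), it falls back to plain cross-entropy:

```r
# Per-sample cross-entropy from a matrix of logits (rows = samples) and integer labels
ce_per_sample <- function(logits, y) {
  m <- apply(logits, 1, max)                               # row maxima for numerical stability
  log_probs <- logits - m - log(rowSums(exp(logits - m)))  # row-wise log-softmax
  -log_probs[cbind(seq_along(y), y)]                       # pick each sample's true class
}

# Hypothetical mixup loss: target is either integer labels or
# list(ys = list(y1, y2), weight = w) with one weight per sample
mixup_loss_sketch <- function(logits, target) {
  if (!is.list(target)) return(mean(ce_per_sample(logits, target)))
  w <- target$weight
  mean(w * ce_per_sample(logits, target$ys[[1]]) +
         (1 - w) * ce_per_sample(logits, target$ys[[2]]))
}

logits <- matrix(c(2, 0, 0, 1), nrow = 2)  # 2 samples, 2 classes
target <- list(ys = list(c(1L, 2L), c(2L, 1L)), weight = c(0.7, 0.4))
mixup_loss_sketch(logits, target)
```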

Looking at how the metric is defined, perhaps I should add a custom metric, luz_metric_acc_mixup(), for example?
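One possible definition for such a metric, which is an assumption and not something settled in this thread, credits each prediction by the mixing weight of the label it matches. The base-R sketch below follows that definition. It is a plain function, not a luz metric object, and its name is hypothetical:

```r
# Hypothetical mixup-aware accuracy (the luz_metric_acc_mixup() idea), as a plain function.
# logits: matrix, rows = samples; target: integer labels or list(ys = list(y1, y2), weight = w)
mixup_accuracy_sketch <- function(logits, target) {
  pred <- max.col(logits, ties.method = "first")   # predicted class per row
  if (!is.list(target)) return(mean(pred == target))
  w <- target$weight
  # A correct prediction earns the weight of the label it matches
  mean(w * (pred == target$ys[[1]]) + (1 - w) * (pred == target$ys[[2]]))
}

logits <- matrix(c(2, 0, 0, 1), nrow = 2)   # predictions: class 1, then class 2
target <- list(ys = list(c(1L, 2L), c(2L, 1L)), weight = c(0.7, 0.4))
mixup_accuracy_sketch(logits, target)       # 0.55
mixup_accuracy_sketch(logits, c(1L, 2L))    # 1
```

If a sample's partner has the same label, a correct prediction earns w + (1 - w) = 1, as expected. Other definitions, e.g. scoring only against the label with the larger weight, would be equally defensible.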

skeydan commented 2 years ago

From what we discussed, one can argue that luz_metric_accuracy() should not be used with this callback. Instead, one can assess training progress from the losses (on the training and test sets, respectively) and do an evaluation afterwards.

In that light, I think the PR could be merged.