mlverse / luz

Higher Level API for torch
https://mlverse.github.io/luz/

Combining early stopping with csv logger callback, metrics for last epoch don't get logged #64

Closed skeydan closed 3 years ago

skeydan commented 3 years ago

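To be precise, the code combines luz_callback_early_stopping() with luz_callback_csv_logger(). The CSV ends at epoch 4, although the console shows that epoch 5 ran and reported metrics before training stopped.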

Code:

fitted <- convnet %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_binary_accuracy_with_logits()
    )
  ) %>%
  fit(train_dl, epochs = c(5,10), valid_data = valid_dl,
      callbacks = list(luz_callback_early_stopping(), luz_callback_csv_logger("logs.csv")),
      verbose = TRUE)

CSV:

1,"train",0.67555393654533,0.559885714285714
1,"valid",0.643152226962006,0.632266666666667
2,"train",0.613922604913032,0.658685714285714
2,"valid",0.567475510558594,0.706666666666667
3,"train",0.571434665587092,0.701885714285714
3,"valid",0.518853313394828,0.740933333333333
4,"train",0.537284926724172,0.731257142857143
4,"valid",0.481547881831238,0.769066666666667

Console:

Train metrics: Loss: 0.5081 - Acc: 0.7493                                              
Valid metrics: Loss: 0.4434 - Acc: 0.7968
Early stopping at epoch 5 of 10
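
One way to confirm the gap is to compare the epochs present in the CSV with the epoch at which the console reports stopping. The following is a minimal sketch in base R, using the rows from the CSV above. The column names (`epoch`, `set`, `loss`, `acc`) are assumptions made for illustration, since the excerpt shows no header row.

```r
# Rows copied verbatim from the CSV excerpt in this report.
csv_text <- '1,"train",0.67555393654533,0.559885714285714
1,"valid",0.643152226962006,0.632266666666667
2,"train",0.613922604913032,0.658685714285714
2,"valid",0.567475510558594,0.706666666666667
3,"train",0.571434665587092,0.701885714285714
3,"valid",0.518853313394828,0.740933333333333
4,"train",0.537284926724172,0.731257142857143
4,"valid",0.481547881831238,0.769066666666667'

# Column names are assumed; the excerpt has no header row.
logs <- read.csv(text = csv_text, header = FALSE,
                 col.names = c("epoch", "set", "loss", "acc"))

# The console reports "Early stopping at epoch 5 of 10".
epochs_run <- 5

# Epochs that ran according to the console but have no rows in the CSV.
missing_epochs <- setdiff(seq_len(epochs_run), unique(logs$epoch))
print(missing_epochs)
```

For the data in this report, the only missing epoch is 5, the one in which early stopping was triggered.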

skeydan commented 3 years ago

I also think that training always stops after min_epochs epochs (here 5, from epochs = c(5,10)), even though the validation loss is still improving:

> fitted <- convnet %>%
+   setup(
+     loss = nn_bce_with_logits_loss(),
+     optimizer = optim_adam,
+     metrics = list(
+       luz_metric_binary_accuracy_with_logits()
+     )
+   ) %>%
+   fit(train_dl, epochs = c(5,10), valid_data = valid_dl,
+       callbacks = list(luz_callback_early_stopping(), luz_callback_csv_logger("logs_resnet.csv")),
+       verbose = TRUE)
Epoch 1/10
[W TensorImpl.h:1156] Warning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (function operator())
Train metrics: Loss: 0.1925 - Acc: 0.928                                               
Valid metrics: Loss: 0.0714 - Acc: 0.976
Epoch 2/10
Train metrics: Loss: 0.1368 - Acc: 0.9455                                              
Valid metrics: Loss: 0.0584 - Acc: 0.9776
Epoch 3/10
Train metrics: Loss: 0.1318 - Acc: 0.9479                                              
Valid metrics: Loss: 0.058 - Acc: 0.9765
Epoch 4/10
Train metrics: Loss: 0.1234 - Acc: 0.9496                                              
Valid metrics: Loss: 0.0564 - Acc: 0.9781
Epoch 5/10
Train metrics: Loss: 0.1246 - Acc: 0.9486                                              
Valid metrics: Loss: 0.0508 - Acc: 0.9811
Early stopping at epoch 5 of 10
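
To illustrate why stopping at epoch 5 looks wrong, here is a minimal sketch of a generic patience-based early stopping rule applied to the validation losses from the log above. It is not luz's implementation. `first_stop_epoch` is a hypothetical helper, and the defaults `patience = 0` and `min_delta = 0` are assumptions chosen to be as strict as possible.

```r
# Validation losses copied from the console output above (epochs 1 to 5).
valid_loss <- c(0.0714, 0.0584, 0.058, 0.0564, 0.0508)

# Hypothetical helper: a generic patience-based rule, not luz's own code.
# Returns the first epoch at which training would stop, or NA if it never does.
first_stop_epoch <- function(losses, patience = 0, min_delta = 0) {
  best <- Inf
  waited <- 0
  for (epoch in seq_along(losses)) {
    if (losses[epoch] < best - min_delta) {
      # Improvement: remember the new best value and reset the counter.
      best <- losses[epoch]
      waited <- 0
    } else {
      # No improvement: stop once the patience budget is exhausted.
      waited <- waited + 1
      if (waited > patience) return(epoch)
    }
  }
  NA_integer_
}

stop_epoch <- first_stop_epoch(valid_loss)
print(stop_epoch)
```

Under these assumptions the rule never triggers on this run, because the validation loss improves in every epoch. This is consistent with the suspicion that the stop at epoch 5 comes from the min_epochs handling and not from the monitored metric.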