mmasana / FACIL

Framework for Analysis of Class-Incremental Learning with 12 state-of-the-art methods and 3 baselines.
https://arxiv.org/pdf/2010.15277.pdf
MIT License

Could I omit the validation dataset? #19

Closed: libo-huang closed this issue 2 years ago

libo-huang commented 2 years ago

Many thanks for your helpful project.

I want to know whether it is possible to omit the validation dataset, since splitting it off reduces the amount of training data, as shown below: https://github.com/mmasana/FACIL/blob/f653d6c0eef52292dd610f3fd412e29315a93ed2/src/datasets/memory_dataset.py#L100-L110

I have tried changing the default value of the parameter that controls this, detailed below, https://github.com/mmasana/FACIL/blob/f653d6c0eef52292dd610f3fd412e29315a93ed2/src/datasets/data_loader.py#L14 but that raised a ZeroDivisionError.
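
Concretely, the change I tried was something like this (a sketch; I am assuming the function at the linked line is get_loaders and that the parameter is the validation split ratio):

def get_loaders(datasets, num_tasks, nc_first_task, batch_size,
                num_workers=1, pin_memory=False, validation=0.0):  # default changed from 0.1 to 0.0
    ...  # rest of the function unchanged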

mmasana commented 2 years ago

Hi, happy you enjoy the project :)

The main issue with removing the validation set is that several tools in the framework use it by default. Training uses early stopping guided by the validation set (and so does the Continual Hyperparameter Framework), so removing it would change how training is intended to work. I do not recommend running without a validation set, but if you want to, I can provide a quick solution.

As you mention, you can fix the validation parameter to 0.0. By doing that, the validation dataset will be empty and so will its loader, so it does not make much sense to pass it to most functions. The ZeroDivisionError you get comes from calculating the accuracy, which divides by the total number of validation samples, which is zero. An easy fix is to set val_loader = trn_loader (e.g. at line 194 of main_incremental.py), so the training set is used as both train and validation. This avoids any errors caused by an empty val_loader. There are other ways to modify the framework to skip validation depending on your specific use case, but this one is the quickest fix.

That would go in the lines after getting the loaders: https://github.com/mmasana/FACIL/blob/f653d6c0eef52292dd610f3fd412e29315a93ed2/src/main_incremental.py#L189-L196
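
As a rough sketch (the call and its arguments are abbreviated here; names follow the linked snippet):

trn_loader, val_loader, tst_loader, taskcla = get_loaders(...)  # existing call, arguments omitted
val_loader = trn_loader  # quick fix: use the training set for validation as well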

libo-huang commented 2 years ago

Thanks for your prompt reply. It does work, although it nearly doubles the training time.

mmasana commented 2 years ago

Strange, the validation set is only used for inference, so it should not double the training time. What you can further do is reduce the amount of data in the val_loader so that it is only a subset of the training set. You would still train with 100% of the data, but use e.g. 10% of it as validation. That should cut the inference time spent on validation by an order of magnitude.

Roughly, you can do that in the lines below val_loader = trn_loader by adding something like:

reduced_val = int(0.1 * len(val_loader.dataset.labels))  # cast to int so it can be used as a slice index
val_loader.dataset.images = val_loader.dataset.images[:reduced_val]
val_loader.dataset.labels = val_loader.dataset.labels[:reduced_val]

Edit: I just realized that since the loaders are lists, you will have to do this for each element of the list. It's just a matter of adding a loop.

libo-huang commented 2 years ago

"the validation is only used during inference, so it should not double the training time." Yes, that is definitely right, but each batch of data still has to be passed through the model for inference; compared with training, the only time saved is skipping the backpropagation. That is why I said it nearly doubles the training time.

"Just realized that since the loaders are lists, you will have to do it for each element of the list. It's just adding the loop." Like this:

val_loader = trn_loader
for val_i in range(len(val_loader)):  # one loader per task
    reduced_val = int(0.1 * len(val_loader[val_i].dataset.labels))
    val_loader[val_i].dataset.images = val_loader[val_i].dataset.images[:reduced_val]
    val_loader[val_i].dataset.labels = val_loader[val_i].dataset.labels[:reduced_val]

This is good enough for me now. Sincere thanks again for your help.

libo-huang commented 2 years ago

Oh no, val_loader = trn_loader should be fixed to val_loader = copy.deepcopy(trn_loader); otherwise the slicing above also truncates the training data, because both loaders share the same dataset objects. So the final solution for this question is:
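
import copy

val_loader = copy.deepcopy(trn_loader)  # deep copy so the training loaders are not modified
for val_i in range(len(val_loader)):
    reduced_val = int(0.1 * len(val_loader[val_i].dataset.labels))
    val_loader[val_i].dataset.images = val_loader[val_i].dataset.images[:reduced_val]
    val_loader[val_i].dataset.labels = val_loader[val_i].dataset.labels[:reduced_val]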

mmasana commented 2 years ago

Perfect! Thanks for pointing out the deepcopy!