ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

ReplayPlugin behaves badly when buffer_size > experience size #1274

Closed AlbinSou closed 1 year ago

AlbinSou commented 1 year ago

🐛 Describe the bug This does not make the code fail with an error, but it makes training very slow and results in behavior that is not what the user expects.

When using ReplayPlugin together with an online scenario, replay is activated starting from the second iteration, and thus the training batch is twice as big from that point on. This is the expected behavior. However, what currently happens is that the strategy additionally trains on the whole buffer, which I don't think is what is supposed to happen.

More precisely, here is the training loop inside SGDUpdate training_epoch:

       for self.mbatch in self.dataloader:
           [...]

Imagine we set a batch size of 10 and an experience dataset size of 10, with a memory of 1000. We should have len(self.dataloader) = 1 for every seen batch (i.e. experience). However, the current behavior is that len(self.dataloader) = min(num_batches_seen_so_far, buffer_size / batch_size). So the number of iterations per batch increases: first 1, then 2, 3, 4, 5 ..., limited by the size of the buffer. In other words, it iterates over the whole buffer every time before moving on to the next batch.
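
To make the growth concrete, here is a small illustration (my own numbers, mirroring the formula above, not Avalanche code) of how many backward passes this produces over 1000 training samples:

# Illustration only (not Avalanche code): count the backward passes for the
# setting above (batch_size = experience_size = 10, mem_size = 1000,
# 1000 training samples -> 100 online experiences).
batch_size, mem_size, n_experiences = 10, 1000, 100
buffer_batches = mem_size // batch_size                  # 100

expected = n_experiences                                 # 1 iteration per experience -> 100
current = sum(min(k, buffer_batches)                     # 1, 2, 3, ... iterations
              for k in range(1, n_experiences + 1))      # -> n*(n+1)//2 = 5050

print(expected, current)   # 100 5050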

As I initially encountered this bug in the online setting (where the effect is most pronounced), I reported it using an online scenario (see below). However, I also noticed that in the non-online setting, when the buffer size is bigger than the current experience dataset size, the number of iterations is increased as well (this time it shows up as a visible increase of the number of iterations in the InteractiveLogger).

🐜 To Reproduce

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.benchmarks.generators import benchmark_with_validation_stream
from avalanche.benchmarks.scenarios.online_scenario import OnlineCLScenario
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SlimResNet18
from avalanche.models.dynamic_modules import IncrementalClassifier
from avalanche.training.plugins import EvaluationPlugin, ReplayPlugin
from avalanche.training.supervised import OnlineNaive
from experiments.utils import create_default_args, set_seed

def replay_scifar10(override_args=None):
    args = create_default_args(
        {
            "cuda": 0,
            "mem_size": 1000,
            "lr": 0.1,
            "train_mb_size": 10,
            "seed": 0,
            "batch_size_mem": 10,
        },
        override_args
    )
    set_seed(args.seed)
    fixed_class_order = np.arange(10)
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )

    scenario = SplitCIFAR10(
        5,
        return_task_id=False,
        seed=args.seed,
        fixed_class_order=fixed_class_order,
        shuffle=True,
        class_ids_from_zero_in_each_exp=False,
    )

    scenario = benchmark_with_validation_stream(scenario, 0.05)
    input_size = (3, 32, 32)
    model = SlimResNet18(10)
    optimizer = SGD(model.parameters(), lr=args.lr)

    interactive_logger = InteractiveLogger()

    loggers = [interactive_logger]

    training_metrics = []

    evaluation_metrics = [
        accuracy_metrics(epoch=True, stream=True),
        loss_metrics(epoch=True, stream=True),
    ]

    # Create main evaluator that will be used by the training actor
    evaluator = EvaluationPlugin(
        *training_metrics,
        *evaluation_metrics,
        loggers=loggers,
    )

    plugins = [ReplayPlugin(mem_size=args.mem_size)]

    #######################
    #  Strategy Creation  #
    #######################

    cl_strategy = OnlineNaive(
        model=model,
        optimizer=optimizer,
        plugins=plugins,
        evaluator=evaluator,
        device=device,
        train_mb_size=args.train_mb_size,
        eval_mb_size=64,
    )

    ###################
    #  TRAINING LOOP  #
    ###################

    print("Starting experiment...")

    print([p.__class__.__name__ for p in cl_strategy.plugins])

    # For online scenario
    batch_streams = scenario.streams.values()

    for t, experience in enumerate(scenario.train_stream):
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        ocl_scenario = OnlineCLScenario(
            original_streams=batch_streams,
            experiences=experience,
            experience_size=10,
            access_task_boundaries=False,
        )

        # Set this to inform strat
        cl_strategy.classes_in_this_experience = experience.classes_in_this_experience

        cl_strategy.train(
            ocl_scenario.train_stream,
            eval_streams=[],
            num_workers=0,
            drop_last=True,
        )

        cl_strategy.eval(scenario.test_stream[: t + 1])

    # Only evaluate at the end on the test stream
    results = cl_strategy.eval(scenario.test_stream)

    return results

if __name__ == "__main__":
    res = replay_scifar10()
    print(res)

The experiments package is the one from continual-learning-baselines repo.

This example runs very slowly, but way faster if you remove the ReplayPlugin from plugins. Also, if you set a breakpoint in the SGDUpdate training loop, you will see that the size of the dataloader grows, whereas it should be the same for each new batch (i.e. experience).

🐝 Expected behavior

The plugin should not modify the number of training iterations specified by the user, i.e. when training on 1000 samples with batch size 10 and one training pass, self.backward should be called exactly 100 times. Here, instead, it is called n*(n+1)/2 times (n being the number of mini-batches) with infinite memory, leading to very long running times and unexpected behavior.

The correct behavior should be equivalent to sampling a batch from the current dataset and concatenating to it a batch from the memory. Once we run out of batches from the current dataset, no more training iterations should be performed.
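
A minimal sketch of that expected loop (my own illustration with hypothetical cur_loader / mem_loader / train_step arguments, not the actual Avalanche implementation):

from itertools import cycle

import torch

def expected_replay_epoch(cur_loader, mem_loader, train_step):
    # The current-experience loader drives the iteration count; the memory
    # loader is simply cycled alongside it.
    mem_iter = cycle(mem_loader) if len(mem_loader) > 0 else None
    for x_cur, y_cur in cur_loader:            # exactly len(cur_loader) iterations
        if mem_iter is not None:
            x_mem, y_mem = next(mem_iter)
            x = torch.cat([x_cur, x_mem])      # batch doubles once replay kicks in
            y = torch.cat([y_cur, y_mem])
        else:
            x, y = x_cur, y_cur
        train_step(x, y)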

HamedHemati commented 1 year ago

Thanks @AlbinSou for the detailed description of the problem.

I actually had similar issues before and opened a pull request to add support for an "interactive" buffer mode in the ReplayPlugin, which could be a solution to this issue as well; you were also involved in that discussion. In that mode, buffer samples are added to each mini-batch by sampling from the buffer. You can have a look here: #1198. But based on what I understood from that conversation, the idea is not to have a general-purpose replay plugin that fits all conditions, so maybe it's better to have a separate replay plugin for online strategies?

Nevertheless, the reason the number of iterations increases is that the ReplayDataLoader used by the ReplayPlugin creates two parallel loaders, one for the current experience samples and one for the buffer samples, and iterates through them in parallel. The length of the replay dataloader is set to the length of the longer of the two loaders. Therefore, if there are more buffer samples than experience samples, the strategy runs more iterations, as you explained.

We can decide which type of solution would be better: change the existing ReplayPlugin, or have a different ReplayPlugin for the online scenario?

AlbinSou commented 1 year ago

Hello @HamedHemati. Yes, I remember this issue, but I had not understood that this behavior existed and thought everything was fine. When you use it in the online setting, the correct number of iterations is shown by the interactive logger, which hides the problem! I had to run profiling to realize that many more backward passes were happening under the hood.

I don't know what the "good" solution is, but I think in general it is sufficient and more natural to iterate over the current data and, for each batch of current data, concatenate a batch from the replay buffer.

So, I think the best way would be, instead of taking the biggest loader as a reference, to take the current experience loader as the only reference and cycle over the replay loader. Even if we then do not see every replay example during the training of one experience, we will eventually iterate over all of them if enough epochs are used, since the replay loader shuffles the replay data.

I don't know whether the best solution is to offer an option to keep the current behavior, or to drop it entirely and replace it with the "online-like" option. I honestly don't think the current behavior is required for any use case, but I might be wrong.

HamedHemati commented 1 year ago

In this strategy: https://github.com/ContinualAI/avalanche/blob/master/avalanche/training/supervised/mer.py , which is an online strategy, I implemented a simple buffer that concatenates samples before each training iteration. If you need a customized replay plugin, the easiest way would be to implement it yourself, which can be done quickly by using existing storage policies.

For the AVL ReplayPlugin, as you said, we could add additional options to control which loader is set as the "main" loader, to avoid similar issues in online strategies. But I'm not sure even that additional option would make the ReplayPlugin a good fit for all strategies. We need to find a good trade-off between the complexity and the "universality" of the ReplayPlugin.

AntonioCarta commented 1 year ago

Do we need to add an option? If I understand correctly, right now we use max(len(buffer), len(dataset)) to choose the number of iterations, which works in batch CL but not in online CL. If we use len(dataset) only, it should work as most people expect in both batch and OCL. In fact, I think that even in batch mode, if len(buffer) > len(dataset), people would not expect the number of iterations to increase.

Let me know if I'm missing something. Otherwise I think we can consider this a bug (or undesired behavior) of the current replay dataloader and switch to always using len(dataset).

HamedHemati commented 1 year ago

@AntonioCarta Yea, that's how it works now. max_len is the maximum of loader lengths: https://github.com/ContinualAI/avalanche/blob/master/avalanche/benchmarks/utils/data_loader.py#L480

max_len = max(
    [
        len(d)
        for d in chain(
            loader_data.values(),
            loader_memory.values(),
        )
    ]
)

It ensures that the loader goes through all samples in both the experience dataset and the buffer at least once, but it doesn't have to be that way. It should be a simple fix if we want to change the default behavior.
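
If the default were changed along those lines, the core of the fix could be as small as basing the length on the experience loaders only (a sketch of the idea against the snippet above, not a tested patch; the memory loaders would then need to be cycled so they don't run out early):

# Sketch: let the current-experience loaders define the number of iterations.
max_len = max(len(d) for d in loader_data.values())
# loader_memory.values() would then be wrapped with itertools.cycle (or
# re-created each epoch) so that memory batches never run out before max_len.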

AlbinSou commented 1 year ago

Do we need to add an option? If I understand correctly, right now we use max(len(buffer), len(dataset)) to choose the number of iterations, which works in batch CL but not in online CL. If we use len(dataset) only, it should work as most people expect in both batch and OCL. In fact, I think that even in batch mode, if len(buffer) > len(dataset), people would not expect the number of iterations to increase.

Let me know if I'm missing something. Otherwise I think we can consider this a bug (or undesired behavior) of the current replay dataloader and switch to always using len(dataset).

I think it makes more sense to consider that as a bug and always use len(loader_data) as a reference.

AntonioCarta commented 1 year ago

@AndreaCossu is it possible that this is causing some failures in the continual-learning-baselines repo? How many scripts do we have that fall under this case? All the OCL methods for sure (e.g. CoPE).

AlbinSou commented 1 year ago

@AndreaCossu is it possible that this is causing some failures in the continual-learning-baselines repo? How many scripts do we have that fall under this case? All the OCL methods for sure (e.g. CoPE).

Not necessarily all of them, since some do not rely on the ReplayPlugin but rather reimplement replay themselves (MIR, AGEM, GEM). But I think CoPE might be affected, since it uses ReplayDataLoader.

AlbinSou commented 1 year ago

@AntonioCarta @HamedHemati I fixed the problem on my local branch (max_len = len(loader_data)).

It indeed fixes the slowness problem, but only partly. I still see a slight slowdown of the replay as the number of tasks grows. For instance, on CIFAR10 with 5 tasks, online, and the ClassBalancedBuffer storage policy, I go from 40 it/s in the first task to 34, 29, 23 and 18 it/s. Still, overall it is way quicker than without this change.

If I use ExperienceBalancedBuffer (which is the default one), it is way slower, which I can understand, since in the online setting the number of experiences grows much faster than the number of classes.

If you want, I can make a PR for the change. However, I don't really know what this loader_data is made of: right now I am just taking len(loader_data), but the code seems to assume that several loaders can live inside that object, so I'd rather be sure that my solution covers every situation.

AntonioCarta commented 1 year ago

just to confirm, is the batch size constant?

If I use ExperienceBalancedBuffer (which is the default one), it is way slower

This may hint at the underlying issue. In this case, there will be one dataloader per experience. I don't think we ever tested how expensive this is, but we may have to implement a faster solution with a single dataloader. It's not complex, but it's a bit more work.
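
As an illustration of the single-dataloader idea (my own sketch, not how ReplayDataLoader currently works), the buffer groups could be merged into one dataset before building the memory loader:

from torch.utils.data import ConcatDataset, DataLoader

def single_memory_loader(buffer_groups, batch_size_mem, num_workers=0):
    # buffer_groups: hypothetical iterable of datasets, e.g. one per experience
    # or per class depending on the storage policy.
    merged = ConcatDataset(list(buffer_groups))
    return DataLoader(
        merged,
        batch_size=batch_size_mem,
        shuffle=True,       # plain shuffling; a weighted sampler could restore balancing
        num_workers=num_workers,
    )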

AlbinSou commented 1 year ago

just to confirm, is the batch size constant?

Yes, the batch size is constant: 10 from the current data and 10 from the memory.

AntonioCarta commented 1 year ago

Ok, then I suspect the overhead comes from the multiple dataloaders.

AlbinSou commented 1 year ago

Ok, then I suspect the overhead comes from the multiple dataloaders.

Well, maybe; I don't really know how this works. What I don't understand is what is inside loader_data.values() apart from the one loader for the current task. If I understand what you say, I guess loader_memory contains one loader for each class or experience, depending on which storage policy is used.

HamedHemati commented 1 year ago

@AlbinSou are you setting the batch_size and batch_size_mem correctly?

When you create the ReplayPlugin, you need to set the batch size according to the strategy's batch size, or just leave it None and it will use the strategy's batch size. I tested it, and the number of iterations is just fine; it runs quickly without additional iterations.

replay_plugin = ReplayPlugin(mem_size=500, storage_policy=storage_policy)

In my case, I set both the mini-experience size and train_mb_size to 10:

cl_strategy = OnlineNaive(
        model,
        torch.optim.Adam(model.parameters(), lr=0.1),
        CrossEntropyLoss(),
        train_passes=1,
        train_mb_size=10,
        eval_mb_size=32,
        device=device,
        evaluator=eval_plugin,
        plugins=[replay_plugin],
    )
...
    for i, exp in enumerate(benchmark.train_stream):
        # Create online scenario from experience exp
        ocl_benchmark = OnlineCLScenario(
            original_streams=batch_streams, experiences=exp, experience_size=10
        )
        # Train on the online train stream of the scenario
        cl_strategy.train(ocl_benchmark.train_stream)
        results.append(cl_strategy.eval(benchmark.test_stream))

HamedHemati commented 1 year ago

For some reason, I didn't see the comments posted after @AlbinSou's update when I wrote my previous comment. Now I see that you've already discussed the batch size matter. As discussed, the problem is no longer the number of iterations, but probably either how the memory loaders are created or how the buffer is updated within the storage policy.

Well, maybe; I don't really know how this works. What I don't understand is what is inside loader_data.values() apart from the one loader for the current task. If I understand what you say, I guess loader_memory contains one loader for each class or experience, depending on which storage policy is used.

If you set task_balanced_dataloader=False in the ReplayPlugin, it will create a single dataloader for all tasks, otherwise there will be one per task.
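
For reference, a configuration along those lines would look roughly like this (argument names as I recall them from the current ReplayPlugin and storage policies; please double-check against the signatures):

from avalanche.training.plugins import ReplayPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer

# Single memory dataloader (no per-task loaders); with batch_size /
# batch_size_mem left as None the strategy's train_mb_size is used.
storage_policy = ClassBalancedBuffer(max_size=500, adaptive_size=True)
replay_plugin = ReplayPlugin(
    mem_size=500,
    batch_size=None,
    batch_size_mem=None,
    task_balanced_dataloader=False,
    storage_policy=storage_policy,
)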

AlbinSou commented 1 year ago

@AlbinSou are you setting the batch_size and batch_size_mem correctly?

When you create the ReplayPlugin, you need to set the batch size according to the strategy's batch size, or just leave it None and it will use the strategy's batch size. I tested it, and the number of iterations is just fine; it runs quickly without additional iterations.

replay_plugin = ReplayPlugin(mem_size=500, storage_policy=storage_policy)

In my case, I set both the mini-experience size and train_mb_size to 10:

cl_strategy = OnlineNaive(
        model,
        torch.optim.Adam(model.parameters(), lr=0.1),
        CrossEntropyLoss(),
        train_passes=1,
        train_mb_size=10,
        eval_mb_size=32,
        device=device,
        evaluator=eval_plugin,
        plugins=[replay_plugin],
    )
...
    for i, exp in enumerate(benchmark.train_stream):
        # Create online scenario from experience exp
        ocl_benchmark = OnlineCLScenario(
            original_streams=batch_streams, experiences=exp, experience_size=10
        )
        # Train on the online train stream of the scenario
        cl_strategy.train(ocl_benchmark.train_stream)
        results.append(cl_strategy.eval(benchmark.test_stream))

Hmm, it's curious that this example works. Are you sure you don't see a slowdown in the first epochs? In my case, I don't see the number of iterations increase in the logger either; it's only when I looked at the number of backward passes that I realized more of them were happening than I expected. Also, with only 500 memory samples the problem is limited, since in the worst case the algorithm does 50 iterations per online experience (which is still a lot, and it gets worse with more memory).
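
For the record, the worst case I mean here (my own arithmetic, using the experience size of 10 from the examples above):

mem_size, batch_size, experience_size = 500, 10, 10
iters_expected = experience_size // batch_size       # 1 iteration per online experience
iters_worst = mem_size // batch_size                  # 50 once the buffer is full
print(iters_expected, iters_worst)                    # 1 50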