ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

strategy API design #1611

Closed: AntonioCarta closed this issue 5 months ago.

AntonioCarta commented 8 months ago

@AlbinSou, to continue the discussion we had on Monday, here is some pseudocode showing what I imagine a "standalone" Avalanche strategy would look like. The idea is that the whole training loop must fit in a single file. Components (optimizer, dataloaders, losses, models) are all independent and composable. Components are updatable objects that receive the strategy state and the current experience only at predefined points. This addresses a common complaint about Avalanche strategies: having to navigate the template hierarchy to understand the training loop.
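A minimal sketch of what an "updatable" component could look like, assuming the `pre_update(exp, agent_state)` / `post_update(exp, agent_state)` hooks from the pseudocode in this issue. `Updatable`, `ReservoirBuffer` and `run_hooks` are hypothetical names, not existing Avalanche APIs, and reservoir sampling is just one possible buffer policy:

```python
import random
from types import SimpleNamespace
from typing import Any, List, Protocol


class Updatable(Protocol):
    """Hypothetical interface: a component touched only at experience boundaries."""

    def pre_update(self, exp: Any, agent_state: Any) -> None: ...

    def post_update(self, exp: Any, agent_state: Any) -> None: ...


class ReservoirBuffer:
    """Toy replay buffer (satisfies Updatable) using reservoir sampling.

    Every sample seen so far has the same chance of being in the buffer."""

    def __init__(self, memory_size: int, seed: int = 0) -> None:
        self.memory_size = memory_size
        self.buffer: List[Any] = []
        self.seen = 0
        self.rng = random.Random(seed)

    def pre_update(self, exp: Any, agent_state: Any) -> None:
        pass  # nothing to prepare before training on exp

    def post_update(self, exp: Any, agent_state: Any) -> None:
        # ingest the experience only after training on it has finished
        for sample in exp.dataset:
            self.seen += 1
            if len(self.buffer) < self.memory_size:
                self.buffer.append(sample)
            else:
                # overwrite a random slot with probability memory_size / seen
                j = self.rng.randrange(self.seen)
                if j < self.memory_size:
                    self.buffer[j] = sample


def run_hooks(objs: List[Updatable], hook: str, exp: Any, agent_state: Any) -> None:
    # call the same hook on every component, in order
    for uo in objs:
        getattr(uo, hook)(exp, agent_state)


buf = ReservoirBuffer(memory_size=4)
for exp in [SimpleNamespace(dataset=range(10)), SimpleNamespace(dataset=range(10, 20))]:
    run_hooks([buf], "pre_update", exp, None)
    # ... the training loop on exp would go here ...
    run_hooks([buf], "post_update", exp, None)

print(len(buf.buffer), buf.seen)  # 4 20
```

Because the buffer only reads `exp.dataset` at a predefined point, it does not need to know how the training loop itself is written.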

If we go in this direction, it is important to do it in a way that is cheap to maintain, so that we don't end up with broken code after a few releases (as happened with many examples).

Some things will not be easy to define as simple "updatable" objects. This is fine; it just means they will be more tightly integrated into the training loop.

Let me know what you think. Of course, we don't need to convert what we already have; this would mostly be to support using Avalanche together with other frameworks. For example, it would make it possible to use Avalanche in Lightning.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from types import SimpleNamespace

# MyReplayBuffer, BetterLwF, MyModel, MyCLOptimizer, MyScheduler, MyDataset,
# maybe_init, AvlDataLoader, AvalancheDataset, split_by_class, ocl_split,
# AvlAccuracy, AvlCE, forgetting_matrix, a_lot and num_experiences are
# placeholders for the components discussed in this proposal.


def my_cool_method(exp, agent_state=None, **extra_args):
    # Defines training for a single experience instead of the whole stream.
    # This should be better because training no longer needs to care about
    # the eval loop. agent_state takes the role of the strategy object.

    if agent_state is None:
        # first experience: create the state and initialize everything
        agent_state = SimpleNamespace()
        # we could even do partial initialization (if not too tricky):
        # maybe_init does not touch already defined arguments
        agent_state.replay = maybe_init(MyReplayBuffer(memory_size=a_lot))
        agent_state.reg_loss = maybe_init(BetterLwF(with_tricks=True))
        agent_state.model = maybe_init(MyModel(...))

        # no partial init for optimizer and scheduler because they depend on the model
        agent_state.opt = MyCLOptimizer(agent_state.model.parameters(), lr=0.001)  # could be GEM, OWM, ...
        agent_state.scheduler = MyScheduler(agent_state.opt, ...)

    # this is the usual before_exp
    updatable_objs = [agent_state.replay, agent_state.reg_loss, agent_state.model]
    for uo in updatable_objs:
        uo.pre_update(exp, agent_state)

    agent_state.model.train()
    # batches mix the current experience with the replay buffer (other options omitted)
    dl = AvlDataLoader(exp.dataset, agent_state.replay.buffer, batch_size=128)
    for x, y in dl:
        agent_state.opt.zero_grad()
        yp = agent_state.model(x)
        loss = F.cross_entropy(yp, y)
        loss = loss + agent_state.reg_loss()
        loss.backward()
        agent_state.opt.step()

    # this is the usual after_exp
    for uo in updatable_objs:
        uo.post_update(exp, agent_state)
    return agent_state


@torch.no_grad()
def my_eval(stream, model, metrics, **extra_args):
    # Eval also becomes simpler. Notice how in Avalanche it's harder to check
    # whether we are evaluating a single experience or the whole stream.
    # Now we evaluate each stream with a separate function call.
    model.eval()

    # these calls are needed to update the metric counters
    for uo in metrics:
        uo.before_stream(stream)
    for exp in stream:
        # evaluation uses only the experience's own data, no replay buffer
        dl = AvlDataLoader(exp.dataset, batch_size=128)
        for x, y in dl:
            yp = model(x)
            for uo in metrics:
                uo.update(yp, y)
        for uo in metrics:
            uo.step()
    for uo in metrics:
        uo.after_stream(stream)


if __name__ == '__main__':
    avl_train_data = AvalancheDataset(MyDataset(train=True))
    avl_test_data = AvalancheDataset(MyDataset(train=False))
    train_stream, test_stream = split_by_class(avl_train_data, avl_test_data, num_experiences)
    ocl_stream = ocl_split(train_stream, num_experiences)

    metrics = [AvlAccuracy(), AvlCE()]

    agent_state = None
    for exp in ocl_stream:
        agent_state = my_cool_method(exp, agent_state)
        my_eval(test_stream, agent_state.model, metrics)

    acc_timeline = metrics[0].result_all()
    plt.plot(acc_timeline)

    fm = forgetting_matrix(acc_timeline)
    plt.matshow(fm)
    plt.show()
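The eval loop relies on metrics exposing `before_stream`, `update`, `step`, `after_stream` and `result_all`. Below is a minimal sketch of an accuracy metric with those hooks plus a `forgetting_matrix`, under several assumptions: `update` receives predicted labels rather than logits (with PyTorch one would pass `yp.argmax(dim=1)`), `result_all()` returns one row of per-experience accuracies per eval call, and forgetting is taken as the best earlier accuracy on an experience minus the current one (one common definition, not necessarily the one intended here). `StreamAccuracy` is a hypothetical name:

```python
from typing import List


class StreamAccuracy:
    """Hypothetical metric with the hooks used by my_eval."""

    def __init__(self) -> None:
        self.timeline: List[List[float]] = []  # one row per eval call
        self._row: List[float] = []
        self._correct = 0
        self._total = 0

    def before_stream(self, stream) -> None:
        self._row = []

    def update(self, yp, y) -> None:
        # yp: predicted labels, y: targets (plain sequences here)
        for p, t in zip(yp, y):
            self._correct += int(p == t)
            self._total += 1

    def step(self) -> None:
        # close the current experience and reset the counters
        self._row.append(self._correct / max(self._total, 1))
        self._correct, self._total = 0, 0

    def after_stream(self, stream) -> None:
        self.timeline.append(self._row)

    def result_all(self) -> List[List[float]]:
        return self.timeline


def forgetting_matrix(acc: List[List[float]]) -> List[List[float]]:
    # fm[t][i] = best accuracy on experience i before eval call t, minus
    # accuracy at call t; the first row is zero (nothing to forget yet)
    fm = []
    for t, row in enumerate(acc):
        fm.append([
            max(acc[s][i] for s in range(t)) - row[i] if t > 0 else 0.0
            for i in range(len(row))
        ])
    return fm


targets = [[0, 0], [1, 1]]  # labels of a two-experience test stream
eval_preds = [
    [[0, 0], [1, 0]],  # predictions after the first training experience
    [[0, 1], [1, 1]],  # predictions after the second training experience
]
metric = StreamAccuracy()
for preds in eval_preds:
    metric.before_stream(targets)
    for yp, y in zip(preds, targets):
        metric.update(yp, y)
        metric.step()
    metric.after_stream(targets)

print(metric.result_all())  # [[1.0, 0.5], [0.5, 1.0]]
print(forgetting_matrix(metric.result_all()))  # [[0.0, 0.0], [0.5, -0.5]]
```

Negative entries mean accuracy on that experience went up, for example because the model had not yet been trained on it at the earlier eval call.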
AlbinSou commented 8 months ago

Hmm, this is not exactly what I imagined, but I like the idea. So this would mean getting rid of the whole "plugin" system?

AntonioCarta commented 8 months ago

Yes, I understand it's not what you meant. I wrote it so that we can compare both ideas on a practical example and discuss them.

So this would mean getting rid of the whole "plugin" system?

Not necessarily; we can keep them. The idea is that users should be able to integrate Avalanche components without having to use Avalanche strategies or plugins. Plugins are more powerful, so they remain useful for us to develop more complex methods and to integrate things that don't fit a simple pre_update/post_update scheme (e.g., many metrics and loggers).
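Keeping plugins and standalone components can go together: one option is a thin adapter that forwards plugin callbacks to an updatable object. A minimal sketch, assuming Avalanche-style callback names (`before_training_exp` / `after_training_exp`) and a strategy that exposes the current experience as `strategy.experience`. `UpdatableAsPlugin` and `CallLog` are hypothetical, and in a real setup the adapter would presumably subclass Avalanche's plugin base class rather than be a plain class:

```python
from types import SimpleNamespace


class UpdatableAsPlugin:
    """Hypothetical adapter: plugin callbacks -> pre_update/post_update."""

    def __init__(self, updatable) -> None:
        self.updatable = updatable

    def before_training_exp(self, strategy, *args, **kwargs) -> None:
        # the strategy object plays the role of agent_state
        self.updatable.pre_update(strategy.experience, strategy)

    def after_training_exp(self, strategy, *args, **kwargs) -> None:
        self.updatable.post_update(strategy.experience, strategy)


class CallLog:
    """Toy updatable object that records which hooks ran."""

    def __init__(self) -> None:
        self.calls = []

    def pre_update(self, exp, agent_state) -> None:
        self.calls.append(("pre", exp))

    def post_update(self, exp, agent_state) -> None:
        self.calls.append(("post", exp))


log = CallLog()
plugin = UpdatableAsPlugin(log)
fake_strategy = SimpleNamespace(experience="exp0")  # stands in for a real strategy
plugin.before_training_exp(fake_strategy)
plugin.after_training_exp(fake_strategy)
print(log.calls)  # [('pre', 'exp0'), ('post', 'exp0')]
```

The adapter only maps the two experience-boundary hooks, so anything that needs per-iteration callbacks still has to be written as a full plugin.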