allenai / allenact

An open source framework for research in Embodied-AI from AI2.
https://www.allenact.org

Add Callback Support #339

Closed mattdeitke closed 2 years ago

mattdeitke commented 2 years ago

Background

Adds initial support for Callbacks, inspired by PyTorch Lightning.

The immediate use case is to enable logging during training with Weights & Biases (wandb).
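To illustrate the Lightning-style pattern: the trainer calls named hook methods on every registered callback at fixed points in the loop. The following is a minimal sketch that assumes only the two hook names used in this PR's example (setup and on_train_log). Callback and CallbackList here are hypothetical stand-ins, not AllenAct's actual implementation.

```python
from typing import Any, Dict, List


class Callback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op,
    so a subclass overrides only the events it cares about."""

    def setup(self, name: str, **kwargs: Any) -> None:
        pass

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs: Any) -> None:
        pass


class CallbackList:
    """Hypothetical helper that forwards each event to every registered callback, in order."""

    def __init__(self, callbacks: List[Callback]) -> None:
        self.callbacks = list(callbacks)

    def setup(self, name: str, **kwargs: Any) -> None:
        for cb in self.callbacks:
            cb.setup(name=name, **kwargs)

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs: Any) -> None:
        for cb in self.callbacks:
            cb.on_train_log(metric_means=metric_means, step=step, **kwargs)
```

Because the base hooks do nothing, the runner can call every hook unconditionally. A user callback, such as a wandb logger, only needs to override the methods it actually uses.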

Motivation

The goal is to make it easier to log, debug, and inspect training without having to manually modify runner.py.

Down the line, I suspect callbacks will also be the best place to write tests, e.g., as assertions inside callback functions like on_checkpoint_load(model).
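A test written as a callback might look like the following minimal sketch. on_checkpoint_load is only the hook name floated above, not an existing AllenAct hook, and FiniteWeightsCheck is a hypothetical name. To keep the sketch dependency-free, it assumes model.state_dict() returns a mapping from parameter name to a flat list of floats.

```python
import math
from typing import Any


class FiniteWeightsCheck:
    """Hypothetical test-style callback: fail fast if a loaded checkpoint
    contains NaN or infinite weights."""

    def on_checkpoint_load(self, model: Any, **kwargs: Any) -> None:
        # Assumption: state_dict() maps names to flat lists of floats.
        for name, values in model.state_dict().items():
            bad = [v for v in values if not math.isfinite(v)]
            if bad:
                raise AssertionError(f"{name} contains {len(bad)} non-finite value(s)")
```

A real PyTorch state_dict() holds tensors, so the inner check would become something like torch.isfinite(tensor).all().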

Example

For example, one might define a Callback subclass in the file training/callbacks/wandb_logging.py:

from typing import Any, Dict, Optional

import wandb
from allenact.base_abstractions.callbacks import Callback

class WandbLogging(Callback):
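    def setup(self, name: str, **kwargs) -> None:
        # Start a wandb run; the remaining keyword arguments are stored as the run config.
        # The project and entity values are specific to this example; substitute your own.
        wandb.init(
            project="test-project",
            entity="prior-ai2",
            name=name,
            config=kwargs,
        )

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs) -> None:
        # Log the averaged metrics, plus the step as an explicit "step" value.
        wandb.log({**metric_means, "step": step})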

    def on_valid_log(
        self,
        metrics: Optional[Dict[str, Any]],
        metric_means: Dict[str, float],
        step: int,
        **kwargs
    ) -> None:
        wandb.log({**metric_means, "step": step})

    def on_test_log(
        self,
        checkpoint: str,
        metrics: Dict[str, Any],
        metric_means: Dict[str, float],
        step: int,
        **kwargs
    ) -> None:
        wandb.log({**metric_means, "step": step})

To use it, one would pass the file to the --callbacks flag of the allenact command:

allenact <...> --callbacks training/callbacks/wandb_logging.py
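The PR text doesn't show how the runner turns the path given to --callbacks into callback objects. One plausible approach, sketched here purely as an assumption (load_callbacks is a hypothetical function, not AllenAct's code), is to import the file with importlib and instantiate each subclass of the callback base class that the file itself defines. This sketch assumes those subclasses take no constructor arguments.

```python
import importlib.util
import inspect
from pathlib import Path
from typing import List, Type


def load_callbacks(path: str, base: Type) -> List[object]:
    """Hypothetical loader: import the file at `path` and instantiate every
    subclass of `base` defined in it."""
    spec = importlib.util.spec_from_file_location(Path(path).stem, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load callbacks from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code

    callbacks = []
    for _, obj in inspect.getmembers(module, inspect.isclass):
        # Skip the base class itself and any classes merely imported into the file.
        if issubclass(obj, base) and obj is not base and obj.__module__ == module.__name__:
            callbacks.append(obj())  # assumption: no-argument constructors
    return callbacks
```

Filtering on `__module__` keeps the base class and other imported classes (such as the `Callback` import in the example file) from being instantiated.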

Note that this doesn't require modifying the experiment configs at all, so the functionality is fully opt-in.

Notes

I'm still thinking about which callbacks would be most useful and what should be passed into each of them.

Right now, I think the best approach for logging videos, images, or other complex information is to save it to disk and then process, log, and delete it inside on_train_log(), but perhaps there's a cleaner solution.
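The save-to-disk workaround can be sketched as a callback that, on each on_train_log() call, logs every file found in a media directory and then deletes it. This is a minimal sketch under stated assumptions: MediaFlushCallback, media_dir, and log_fn are hypothetical names, and log_fn stands in for whatever actually uploads the file.

```python
from pathlib import Path
from typing import Callable, Dict


class MediaFlushCallback:
    """Hypothetical callback: media written to `media_dir` during training is
    logged and then deleted the next time on_train_log runs."""

    def __init__(self, media_dir: str, log_fn: Callable[[Path, int], None]) -> None:
        self.media_dir = Path(media_dir)
        self.log_fn = log_fn  # assumption: uploads one file for the given step

    def on_train_log(self, metric_means: Dict[str, float], step: int, **kwargs) -> None:
        if not self.media_dir.is_dir():
            return  # nothing has been written yet
        # Sort so files are logged in a deterministic order.
        for path in sorted(p for p in self.media_dir.iterdir() if p.is_file()):
            self.log_fn(path, step)
            path.unlink()  # delete only after the file has been logged
```

With wandb, log_fn could wrap wandb.Image or wandb.Video, both of which accept a file path. The sketch does not handle files that are still being written when the flush runs.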

lgtm-com[bot] commented 2 years ago

This pull request introduces 2 alerts when merging bc49d476c8ef69528c550b4c14a72fe52223bd44 into cc0d1231e57d2850225e47edd62bbbccc253f8b9 - view on LGTM.com

new alerts: