facebookresearch / dora

Dora is an experiment management framework. It expresses grid searches as pure Python files that live in your repo. It identifies experiments with a unique hash signature. Scale up to hundreds of experiments without losing your sanity.
MIT License

Support for custom resolvers with Hydra #59

Closed: aRI0U closed this issue 7 months ago

aRI0U commented 7 months ago

❓ Questions

Hi, with Hydra and, more generally, OmegaConf, it is possible to register new resolvers that apply custom functions directly within the YAML configuration. For example, I can do the following:

import hydra
from omegaconf import OmegaConf

def effective_lr(base_lr: float, batch_size: int) -> float:
    return base_lr * batch_size / 256

OmegaConf.register_new_resolver("effective_lr", effective_lr)

@hydra.main(...)
def main(cfg):
    ...

which lets me write the following directly in my YAML file:

data:
    batch_size: 128
    ....
model:
    ...
    optimizer:
        _target_: torch.optim.Adam
        _partial_: true
        lr: ${effective_lr:${model.base_lr},${data.batch_size}}
    base_lr: 0.0003
    ...

However, this no longer works with Dora's hydra_main. When I use, e.g., dora run, the first thing executed is main.get_xp(), and for some reason this function resolves the whole config under the hood, which raises an InterpolationError because the custom resolver hasn't been registered yet.

The only workaround I found is to wrap hydra_main directly:

from typing import Callable, Dict, Optional

from dora import hydra_main
from omegaconf import DictConfig, OmegaConf

def my_hydra_main(config_name: str, config_path: str, extra_resolvers: Optional[Dict[str, Callable]] = None, **kwargs):
    """Wrap your main function with this.
    You can pass extra kwargs, e.g. `version_base` introduced in 1.2.
    `extra_resolvers` maps resolver names to functions; they are registered
    with OmegaConf when the decorator is applied, i.e. at import time.
    """
    extra_resolvers = extra_resolvers or {}
    for name, resolver in extra_resolvers.items():
        OmegaConf.register_new_resolver(name, resolver)
    return hydra_main(config_name=config_name, config_path=config_path, **kwargs)

# effective_lr is the resolver defined earlier.
@my_hydra_main(version_base="1.3",
               config_path="../configs",
               config_name="train.yaml",
               extra_resolvers={"effective_lr": effective_lr})
def main(cfg: DictConfig) -> Optional[float]:
    ...

This works quite well; however, I have three questions:

  1. Is it necessary to resolve the config directly when calling main.get_xp() in Dora, given that the cfg argument of the DecoratedMain is not resolved anyway? If so, why?
  2. Is there already a way to register custom resolvers while using Dora that I may have missed?
  3. If not, should I open a PR to replace hydra_main with my_hydra_main in the next version? Since it only adds an optional argument that the original hydra_main does not use, it hopefully shouldn't break anything.

Thanks a lot!

adefossez commented 7 months ago

I'm wondering: why not do it outside of my_hydra_main? It doesn't seem like you are using any of the Hydra API to register the new resolvers, so you could have something like:

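from dora import hydra_main
from omegaconf import OmegaConf

# Maps resolver names to functions, e.g. the effective_lr resolver from the question.
extra_resolvers = {"effective_lr": effective_lr}

def register_my_resolver():
    for name, resolver in extra_resolvers.items():
        OmegaConf.register_new_resolver(name, resolver)

@hydra_main(...)
def main(cfg):
    ...

register_my_resolver()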

Note that the position of the call to register_my_resolver() doesn't actually matter; it could also come before main(). I would tend to favor this option, as it makes clear from the code that the registration of the resolvers leaks to every other use of OmegaConf in the process, and I don't think a PR is needed for this use case. Potentially, the README could be updated to show this pattern.

Regarding question 1, get_xp() needs to resolve the config because this is how the signature is derived: it resolves the full config, compares it with the base config, and builds a delta. The hash of that delta is the signature.
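To make the delta-and-hash idea concrete, here is a stdlib-only sketch. It is not Dora's actual implementation: the flattening scheme, the JSON serialization, SHA-1, and the 8-character truncation are all assumptions chosen for illustration. Because the delta is computed on resolved values, a derived field such as `lr` shows up in it alongside the `batch_size` change that produced it.

```python
import hashlib
import json
from typing import Any, Dict, List, Tuple


def flatten(cfg: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Turn nested dicts into {"a.b.c": value} pairs."""
    flat: Dict[str, Any] = {}
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, path + "."))
        else:
            flat[path] = value
    return flat


def config_delta(base: Dict[str, Any], resolved: Dict[str, Any]) -> List[Tuple[str, Any]]:
    """Sorted (key, value) pairs whose resolved value differs from the base.

    Keys present only in the base are ignored in this sketch.
    """
    base_flat, new_flat = flatten(base), flatten(resolved)
    return sorted(
        (key, value)
        for key, value in new_flat.items()
        if key not in base_flat or base_flat[key] != value
    )


def signature(delta: List[Tuple[str, Any]]) -> str:
    # Hypothetical hashing scheme: SHA-1 of the JSON-encoded delta, truncated.
    payload = json.dumps(delta, sort_keys=True).encode("utf-8")
    return hashlib.sha1(payload).hexdigest()[:8]


# Already-resolved configs: lr = base_lr * batch_size / 256.
base = {"data": {"batch_size": 128}, "model": {"base_lr": 0.0003, "lr": 0.00015}}
run_a = {"data": {"batch_size": 256}, "model": {"base_lr": 0.0003, "lr": 0.0003}}
run_b = {"data": {"batch_size": 256}, "model": {"base_lr": 0.0003, "lr": 0.0003}}

print(config_delta(base, run_a))
print(signature(config_delta(base, run_a)), signature(config_delta(base, run_b)))
```

Two runs that resolve to the same values get the same signature, which is what lets an experiment be found again from its overrides.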

aRI0U commented 7 months ago

Indeed, your solution works.

I initially tried something similar, but since I had put register_new_resolver() inside my if __name__ == '__main__': block, it was never executed, which made me (wrongly) think that, when using the dora command, only the code inside hydra_main was executed.

Maybe mentioning this point in the README would be helpful, stressing that register_new_resolver() should be called outside the if __name__ == '__main__': block, which is a bit unusual.
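The behavior described above follows from how Python sets `__name__`: code under `if __name__ == '__main__':` runs only when the file is executed as a script, not when it is imported. The sketch below assumes, consistent with what was observed in this thread, that the dora command imports the training module rather than running it as `__main__`. It writes a throwaway module to a temporary directory and imports it with `importlib`; the module name `fake_train` and the `REGISTERED` list are hypothetical.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

MODULE_SOURCE = textwrap.dedent(
    """
    REGISTERED = []

    # Module level: runs on import, e.g. when a launcher imports this file.
    REGISTERED.append("module_level")

    if __name__ == "__main__":
        # Only runs for `python fake_train.py`, never on import.
        REGISTERED.append("main_guard")
    """
)

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "fake_train.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new file is visible to the import system
    try:
        module = importlib.import_module("fake_train")
    finally:
        sys.path.remove(tmp)

print(module.REGISTERED)
```

Only the module-level registration survives an import, which is why calling register_new_resolver() at module level works with dora while the `__main__`-guarded version does not.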

Thanks a lot!