google / gin-config

Gin provides a lightweight configuration framework for Python
Apache License 2.0

Evaluate functions with parameters in configs #178

Open HendrikSchmidt opened 1 year ago

HendrikSchmidt commented 1 year ago

We are currently working on a project that uses random search to find hyperparameters, and we are looking for a way to achieve that with gin. Our ideal solution would look like this:

import Model
import random_search

Model.learning_rate = @random_search(1, 2, 3, 4, 5)
Model.num_layers = @random_search(8, 16, 32)

Here, gin would call the function random_search, which we have defined somewhere in our project, with the given arguments.

In general, being able to use functions that are evaluated while parsing would be nice to have; I can see other use cases benefiting as well.
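For context, a minimal version of the random_search function we have in mind (the name and signature are ours, not part of gin) would simply pick one of the given candidate values:

```python
import random

def random_search(*candidates):
    """Pick one of the given candidate values uniformly at random."""
    return random.choice(candidates)

random_search(1, 2, 3, 4, 5)  # one of 1..5
```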

We explored different workarounds, first decorating random_search itself with @gin.configurable and having it add a macro to the config, which leads to something like this:

import Model
import random_search

random_search.learning_rate = [1, 2, 3, 4, 5]
Model.learning_rate = %LEARNING_RATE
random_search.num_layers = [8, 16, 32]
Model.num_layers = %NUM_LAYERS

This works, but it is a bit more cumbersome, as every function I'd want to use like this now needs to bind the macros itself. Additionally, this syntax might be confusing to new users, as it is unclear where the macro binding comes from.
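A stand-in sketch of that workaround, with gin's macro table modeled as a plain dict (illustrative only, not the real gin API), looks like this:

```python
import random

MACROS: dict[str, object] = {}  # stand-in for gin's macro table

def random_search(**param_candidates):
    """Pick one value per parameter and register it under an uppercase
    macro name, so it could be referenced as %LEARNING_RATE etc."""
    for name, values in param_candidates.items():
        MACROS[name.upper()] = random.choice(values)

random_search(learning_rate=[1, 2, 3, 4, 5], num_layers=[8, 16, 32])
```

In the real workaround, random_search is decorated with @gin.configurable and registers the chosen values with gin instead of a dict.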

We also intermittently preprocessed/parsed the gin files ourselves and passed the rewritten files to gin. There we allowed a syntax like the one in the example above (only for one specific function, not in general), which got replaced by the evaluated function call: e.g. the line Model.learning_rate = random_search(1, 2, 3, 4, 5) became Model.learning_rate = 4 before gin parsed the contents. However, our parser was quite brittle, and this didn't work for included files, as those were parsed only by gin and our syntax of course didn't work in the gin parser.

This is why we changed to our current approach:

...
# Optimizer params
Adam.weight_decay = 1e-6
optimizer/random_search.class_to_configure = @Adam
optimizer/random_search.lr = [3e-4, 1e-4, 3e-5, 1e-5]

# Encoder params
LSTMNet.input_dim = %EMB
LSTMNet.num_classes = %NUM_CLASSES
model/random_search.class_to_configure = @LSTMNet
model/random_search.hidden_dim = [32, 64, 128, 256]
model/random_search.layer_dim = [1, 2, 3]

run_random_searches.scopes = ["model", "optimizer"]

run_random_searches.scopes defines the scopes that the random search runs in. Each scope represents a class that will get bindings with randomly searched parameters. In this example, we have the two scopes model and optimizer. For each scope, a class_to_configure needs to be set to the class it represents, in this case LSTMNet and Adam respectively. We can add whatever parameters we want to the classes using this syntax:

run_random_searches.scopes = ["<scope>", ...]
<scope>/random_search.class_to_configure = @<SomeClass>
<scope>/random_search.<param> = ['list', 'of', 'possible', 'values']

The scopes take care of adding the parameters only to the pertinent classes, while the random_search() function actually chooses a value at random and binds it to the gin configuration.

If we want to override the model configuration in a different file, this can be done easily:

include "configs/models/LSTM.gin"

Adam.lr = 1e-4

model/random_search.hidden_dim = [100, 200]

This configuration, for example, overrides the lr parameter of Adam with a concrete value, while it only specifies a different search space for hidden_dim of LSTMNet for the random search to run on.

The code running the random search looks like this:

import gin
import numpy as np

@gin.configurable
def random_search(class_to_configure: type = gin.REQUIRED, **kwargs: dict[str, list]) -> list[tuple[str, object]]:
    """Randomly searches parameters for a class and returns gin bindings.

    Args:
        class_to_configure: The class that gets configured with the parameters.
        kwargs: A dict mapping each parameter name to a list of possible values.

    Returns:
        The randomly searched (parameter, value) pairs.
    """
    randomly_searched_params = []
    for param, values in kwargs.items():
        param_to_set = f"{class_to_configure.__name__}.{param}"
        if f"{param_to_set}=" in gin.config_str().replace(" ", ""):
            continue  # hyperparameter is already set in the config (e.g. from an experiment), so skip the random search
        value = values[np.random.randint(len(values))]
        randomly_searched_params += [(param_to_set, value)]
    return randomly_searched_params
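As an aside, the "already set" check above is a plain substring test on the rendered config; isolated from gin, it behaves like this (config_str here is a stand-in for the output of gin.config_str()):

```python
def is_param_set(config_str: str, param_to_set: str) -> bool:
    """True if "Class.param =" already appears in the rendered
    config text, ignoring whitespace."""
    return f"{param_to_set}=" in config_str.replace(" ", "")

config = "Adam.lr = 1e-4\nLSTMNet.hidden_dim = 64"
is_param_set(config, "Adam.lr")            # True
is_param_set(config, "Adam.weight_decay")  # False
```

Note that, being a pure substring test, it would also match a longer class name that merely ends with the target (e.g. MyAdam.lr when checking Adam.lr), which is an accepted limitation in our setup.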

@gin.configurable
def run_random_searches(scopes: list[str] = gin.REQUIRED) -> list[tuple[str, object]]:
    """Executes random searches for the different scopes defined in gin configs.

    Args:
        scopes: The gin scopes to explicitly set.

    Returns:
        The randomly searched (parameter, value) pairs.
    """
    randomly_searched_params = []
    for scope in scopes:
        with gin.config_scope(scope):
            randomly_searched_params += random_search()
    return randomly_searched_params

This works fairly well for our current setup, but natively supporting function evaluation with parameters would still be preferable.

Have there been any discussions on this topic that I missed, or are there counterarguments to supporting function calls? Or did I simply miss some existing functionality that already enables something like this?