kristeligt-dagblad / dbt_ml

Package for dbt that allows users to train, audit and use BigQuery ML models.
Apache License 2.0

Add support for hyperparameter tuning model options #22

Closed by juszzz 3 years ago

juszzz commented 3 years ago

BigQuery ML recently added experimental support for hyperparameter tuning of models. According to the syntax, any parameter you want to tune is specified in the model options block as hyperparameter = { HPARAM_RANGE(min, max) | HPARAM_CANDIDATES([candidates]) }.

Unfortunately, it seems like this doesn't work out of the box with dbt_ml. For example, adding 'num_estimators': hparam_range(1, 10) to a boosted tree model configuration results in a compilation error. I think this is because Jinja is unable to find a function called hparam_range. Putting the entire thing in quotes, e.g., 'num_estimators': 'hparam_range(1, 10)', causes BigQuery to reject the query, since it isn't expecting a string there. Anyone have any thoughts on the correct way to pass these BigQuery "functions" through unchanged?

rbjerrum commented 3 years ago

Unfortunately, I don't think this is possible with the current implementation. The error is raised because all Jinja strings get quoted by the json filter (see https://github.com/kristeligt-dagblad/dbt_ml/blob/master/macros/materializations/model.sql#L32).
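To make the quoting behaviour concrete, here is a rough Python sketch. It uses json.dumps as a stand-in for the Jinja json filter, and render_option is a hypothetical helper, not code from dbt_ml. It only illustrates why a string value ends up as a quoted literal in the generated options.

```python
import json


def render_option(name, value):
    # Hypothetical helper: serialize every value as JSON, roughly the way
    # a JSON filter would. Strings gain surrounding double quotes.
    return f"{name} = {json.dumps(value)}"


# The tuning directive is passed as a string, so it is emitted as a
# string literal rather than as a function call.
quoted = render_option("num_estimators", "hparam_range(1, 10)")
print(quoted)

# Numbers are unaffected, which is why plain numeric options work.
print(render_option("max_iterations", 50))
```

The first line printed contains "hparam_range(1, 10)" wrapped in double quotes, which BigQuery reads as a string value instead of a tuning range.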

I definitely think we should support the hparam_* functions. A set of macros with names matching the corresponding hparam_* functions could be an elegant solution. Open to suggestions, though!
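One way such wrappers could work is sketched below in Python. This is an illustration of the idea only, not the macros that dbt_ml ships. The RawSql marker class and the render_options helper are assumptions made up for this example: the wrappers return a marked string, and the renderer skips quoting for marked values.

```python
import json


class RawSql(str):
    """Hypothetical marker type: a string to emit without quoting."""


def hparam_range(low, high):
    # Build the BigQuery ML directive and mark it as raw SQL.
    return RawSql(f"hparam_range({low}, {high})")


def hparam_candidates(candidates):
    # The candidate list is serialized as a quoted array literal.
    return RawSql(f"hparam_candidates({json.dumps(candidates)})")


def render_options(options):
    # Quote ordinary values as JSON, but pass RawSql values through as-is.
    parts = []
    for name, value in options.items():
        rendered = value if isinstance(value, RawSql) else json.dumps(value)
        parts.append(f"{name} = {rendered}")
    return ",\n".join(parts)


options = {
    "model_type": "dnn_classifier",
    "learn_rate": hparam_range(0.01, 0.1),
    "optimizer": hparam_candidates(["adam", "sgd"]),
    "num_trials": 4,
}
print(render_options(options))
```

With this approach the model type is still quoted as a string, while the two tuning directives reach the generated statement unquoted.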

rbjerrum commented 3 years ago

@juszzz, I've added support for tuning hyperparameters on the master branch. You can now do something like the following:

{{
    config(
        materialized='model',
        ml_config={
            'model_type': 'dnn_classifier',
            'auto_class_weights': true,
            'learn_rate': dbt_ml.hparam_range(0.01, 0.1),
            'early_stop': false,
            'max_iterations': 50,
            'num_trials': 4,
            'optimizer': dbt_ml.hparam_candidates(['adam', 'sgd'])
        }
    )
}}

Can I ask you to give it a spin? You can fetch from git instead of the dbt package hub as described here: https://docs.getdbt.com/docs/building-a-dbt-project/package-management#git-packages

juszzz commented 3 years ago

@rbjerrum works like a charm! Thanks for implementing this feature request!!

rbjerrum commented 3 years ago

Great to hear it, @juszzz. I will update the README and get this out in v0.5.1.