c-bata / goptuna

A hyperparameter optimization framework, inspired by Optuna.
https://pkg.go.dev/github.com/c-bata/goptuna
MIT License
bandit-algorithms bayesian-optimization blackbox-optimization evolution-strategies

Goptuna


Decentralized hyperparameter optimization framework, inspired by Optuna [1]. This library is particularly designed for machine learning, but you can optimize anything for which you can define an objective function (e.g., the number of goroutines of your server or the memory buffer size of a caching system).

Supported algorithms:

Goptuna supports various state-of-the-art Bayesian optimization, evolution strategy, and multi-armed bandit algorithms. All algorithms are implemented in pure Go and continuously benchmarked on GitHub Actions.

Installation

Goptuna is written in pure Go, so you can integrate it into a wide variety of Go projects.

$ go get -u github.com/c-bata/goptuna

Usage

Goptuna supports a Define-by-Run style API like Optuna, so you can construct the search space dynamically while the objective function runs.

Basic usage

package main

import (
    "log"
    "math"

    "github.com/c-bata/goptuna"
    "github.com/c-bata/goptuna/tpe"
)

// ① Define an objective function which returns a value you want to minimize.
func objective(trial goptuna.Trial) (float64, error) {
    // ② Define the search space via Suggest APIs.
    x1, err := trial.SuggestFloat("x1", -10, 10)
    if err != nil {
        return 0, err
    }
    x2, err := trial.SuggestFloat("x2", -10, 10)
    if err != nil {
        return 0, err
    }
    return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
    // ③ Create a study which manages each experiment.
    study, err := goptuna.CreateStudy(
        "goptuna-example",
        goptuna.StudyOptionSampler(tpe.NewSampler()))
    if err != nil {
        log.Fatal(err)
    }

    // ④ Evaluate your objective function.
    err = study.Optimize(objective, 100)
    if err != nil {
        log.Fatal(err)
    }

    // ⑤ Print the best evaluation parameters.
    v, err := study.GetBestValue()
    if err != nil {
        log.Fatal(err)
    }
    p, err := study.GetBestParams()
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("Best value=%f (x1=%f, x2=%f)",
        v, p["x1"].(float64), p["x2"].(float64))
}

Link: Go Playground
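Because both squared terms in the objective above are non-negative, it has a unique minimum inside the search range $[-10, 10]^2$. A well-converged study should therefore report a best value close to 0, with `x1` near 2 and `x2` near -5 (how close depends on the sampler and the number of trials):

```latex
f(x_1, x_2) = (x_1 - 2)^2 + (x_2 + 5)^2 \ge 0,
\qquad
f(x_1, x_2) = 0 \iff (x_1, x_2) = (2, -5)
```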

Furthermore, I recommend using the RDB storage backend for the following purposes.

Built-in Web Dashboard

You can check optimization results with the built-in web dashboard.

(Screenshots: managing optimization results; interactive live-updating graphs.)

Advanced Usage

Parallel optimization with multiple goroutine workers

The `Optimize` method of `goptuna.Study` is goroutine-safe, so you can easily optimize your objective function using multiple goroutine workers.

```go
package main

import ...

func main() {
    study, _ := goptuna.CreateStudy(...)

    eg, ctx := errgroup.WithContext(context.Background())
    study.WithContext(ctx)
    for i := 0; i < 5; i++ {
        eg.Go(func() error {
            return study.Optimize(objective, 100)
        })
    }
    if err := eg.Wait(); err != nil { ... }
    ...
}
```

[full source code](./_examples/concurrency/main.go)
Distributed optimization using MySQL

There is no complicated setup to use the RDB storage backend. First, set up a MySQL server as follows to share the optimization results.

```console
$ docker pull mysql:8.0
$ docker run \
  -d \
  --rm \
  -p 3306:3306 \
  -e MYSQL_USER=goptuna \
  -e MYSQL_DATABASE=goptuna \
  -e MYSQL_PASSWORD=password \
  -e MYSQL_ALLOW_EMPTY_PASSWORD=yes \
  --name goptuna-mysql \
  mysql:8.0
```

Then, create a study using the Goptuna CLI.

```console
$ goptuna create-study --storage mysql://goptuna:password@localhost:3306/yourdb --study yourstudy
yourstudy
```

```mysql
$ mysql --host 127.0.0.1 --port 3306 --user goptuna -ppassword -e "SELECT * FROM studies;"
+----------+------------+-----------+
| study_id | study_name | direction |
+----------+------------+-----------+
|        1 | yourstudy  | MINIMIZE  |
+----------+------------+-----------+
1 row in set (0.00 sec)
```

Finally, run Goptuna workers containing the following code. You can run distributed optimization simply by executing this program on multiple server instances.

```go
package main

import ...

func main() {
    db, _ := gorm.Open(mysql.Open("goptuna:password@tcp(localhost:3306)/yourdb?parseTime=true"), &gorm.Config{
        Logger: logger.Default.LogMode(logger.Silent),
    })
    // GORM v2's *gorm.DB has no Close method; close the underlying *sql.DB instead.
    sqlDB, _ := db.DB()
    defer sqlDB.Close()
    storage := rdb.NewStorage(db)

    study, _ := goptuna.LoadStudy(
        "yourstudy",
        goptuna.StudyOptionStorage(storage),
        ...,
    )
    _ = study.Optimize(objective, 50)
    ...
}
```

Full source code is available [here](./_examples/simple_rdb/main.go).
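The CLI step takes a URL-style storage string, while the worker passes a go-sql-driver/mysql DSN to `gorm.Open`. The hypothetical helper `cliURLToDSN` below is not part of Goptuna; it only illustrates that both strings describe the same database. It assumes a plain `mysql://user:password@host:port/dbname` URL and appends `parseTime=true` as the worker code does.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// cliURLToDSN is a hypothetical helper (not part of Goptuna) that converts
//   mysql://user:password@host:port/dbname
// into the go-sql-driver/mysql DSN form
//   user:password@tcp(host:port)/dbname?parseTime=true
func cliURLToDSN(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	if u.Scheme != "mysql" {
		return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
	}
	if u.User == nil {
		return "", fmt.Errorf("missing user info in %q", raw)
	}
	user := u.User.Username()
	pass, _ := u.User.Password()
	db := strings.TrimPrefix(u.Path, "/")
	return fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", user, pass, u.Host, db), nil
}

func main() {
	dsn, err := cliURLToDSN("mysql://goptuna:password@localhost:3306/yourdb")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(dsn) // goptuna:password@tcp(localhost:3306)/yourdb?parseTime=true
}
```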
Receive notifications for each trial

You can receive a notification for each trial via a channel. This can be used for logging or any other notification system.

```go
package main

import ...

func main() {
    trialchan := make(chan goptuna.FrozenTrial, 8)
    study, _ := goptuna.CreateStudy(
        ...
        goptuna.StudyOptionIgnoreObjectiveErr(true),
        goptuna.StudyOptionSetTrialNotifyChannel(trialchan),
    )

    var err error
    var wg sync.WaitGroup
    wg.Add(2)
    go func() {
        defer wg.Done()
        err = study.Optimize(objective, 100)
        close(trialchan)
    }()
    go func() {
        defer wg.Done()
        for t := range trialchan {
            log.Println("trial", t)
        }
    }()
    wg.Wait()
    if err != nil { ... }
    ...
}
```

[full source code](./_examples/trialnotify/main.go)
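The notification pattern relies on two details: the producer closes the channel once `Optimize` returns, which ends the consumer's `range` loop, and `err` is only read after `wg.Wait()`, so there is no data race. The self-contained sketch below reproduces that pattern with a hypothetical `frozenTrial` type and a toy `runTrials` producer standing in for `study.Optimize`; neither is Goptuna API.

```go
package main

import (
	"fmt"
	"sync"
)

// frozenTrial is a hypothetical, minimal stand-in for goptuna.FrozenTrial.
type frozenTrial struct {
	Number int
	Value  float64
}

// runTrials plays the role of study.Optimize: after each trial finishes it
// sends a snapshot on the notification channel. Hypothetical helper.
func runTrials(n int, notify chan<- frozenTrial) error {
	for i := 0; i < n; i++ {
		v := float64((i - 3) * (i - 3)) // toy objective value
		notify <- frozenTrial{Number: i, Value: v}
	}
	return nil
}

// collect runs the producer and a consumer concurrently. The producer closes
// the channel when it is done so that the consumer's range loop terminates.
func collect(n int) ([]frozenTrial, error) {
	trialchan := make(chan frozenTrial, 8)
	var (
		wg       sync.WaitGroup
		err      error
		received []frozenTrial
	)
	wg.Add(2)
	go func() {
		defer wg.Done()
		err = runTrials(n, trialchan)
		close(trialchan)
	}()
	go func() {
		defer wg.Done()
		for t := range trialchan {
			received = append(received, t)
		}
	}()
	wg.Wait() // both err and received are safe to read after this point
	return received, err
}

func main() {
	trials, err := collect(10)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	for _, t := range trials {
		fmt.Println("trial", t.Number, "value", t.Value)
	}
}
```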

Links

References:

License

This software is licensed under the MIT license, see LICENSE for more information.