AIworx-Labs / chocolate

A fully decentralized hyperparameter optimization framework
http://chocolate.readthedocs.io
BSD 3-Clause "New" or "Revised" License
121 stars 41 forks source link

Fix KeyError: '_loss' while using crossvalidation with mongodb #42

Open andrewssobral opened 3 months ago

andrewssobral commented 3 months ago

This fixes the error that occurs when using cross-validation with MongoDB. The same code works fine with SQLite, but fails with MongoDB. Not tested with other databases.

Code to reproduce the error:

import multiprocessing
import numpy as np
import chocolate as choco

# Himmelblau's function
def objective_function(x, y):
    """Evaluate Himmelblau's function at the point (x, y).

    Returns ``(x**2 + y - 11)**2 + (x + y**2 - 7)**2``.
    """
    first_term = x**2 + y - 11
    second_term = x + y**2 - 7
    return first_term**2 + second_term**2

def score_himmelblau(params):
    """Return the loss for one sampled parameter set.

    Reads the ``'x'`` and ``'y'`` entries of ``params`` and evaluates
    ``objective_function`` on them. Chocolate minimizes the returned loss.
    """
    return objective_function(params['x'], params['y'])

def worker(params, token, result_queue):
    """Score one parameter set and hand the result back through a queue.

    Prints the token, the parameters and the computed loss, then puts the
    two-element list ``[token, loss]`` on ``result_queue``.
    """
    loss = score_himmelblau(params)
    message = f"token: {token}, params: {params}, loss:{loss}"
    print(message)
    outcome = [token, loss]
    result_queue.put(outcome)

def main(total_loops=2, parallel_samples=2):
    """Run the random-search reproduction against a MongoDB connection.

    For each of ``total_loops`` rounds, draws ``parallel_samples`` candidates
    from the sampler, evaluates each one in its own process via ``worker``,
    waits for all processes to finish, and then feeds every queued
    ``[token, loss]`` pair back to the sampler.

    Args:
        total_loops: Number of sample/update rounds.
        parallel_samples: Number of worker processes started per round.
    """
    # SQLite alternative to the MongoDB connection below:
    #   choco.SQLiteConnection(url="sqlite:///himmelblau.db")
    connection = choco.MongoDBConnection("mongodb://user:pass@server:port")
    connection.clear()

    search_space = {
        "x": choco.uniform(-6, 6),
        "y": choco.uniform(-6, 6),
    }

    repetitions = 2  # Number of repetitions
    if repetitions > 1:
        validator = choco.Repeat(repetitions=repetitions, reduce=np.mean, rep_col="_repetition_id")
    else:
        validator = None

    sampler = choco.Random(connection, search_space, crossvalidation=validator)

    for _round in range(total_loops):
        results = multiprocessing.Queue()
        running = []

        print("Sampling")
        for _slot in range(parallel_samples):
            token, params = sampler.next()
            child = multiprocessing.Process(target=worker, args=(params, token, results))
            running.append(child)
            child.start()

        # NOTE(review): the workers are joined before the queue is drained.
        # The multiprocessing docs advise draining first when payloads can be
        # large; payloads here are a small token/loss pair - confirm this
        # stays true if the script is extended.
        for child in running:
            child.join()

        print("Updating")
        while True:
            if results.empty():
                break
            token, loss = results.get()
            sampler.update(token, loss)

# Run the reproduction (3 rounds, 3 worker processes per round) only when this
# file is executed as a script.
if __name__ == "__main__":
    main(total_loops=3, parallel_samples=3)

Output before the fix:

Sampling
token: {'_repetition_id': 0, '_chocolate_id': 0}, params: {'x': 5.052463796795994, 'y': -3.7657521620123244}, loss:265.4677870603498
token: {'_repetition_id': 1, '_chocolate_id': 0}, params: {'x': 5.052463796795994, 'y': -3.7657521620123244}, loss:265.4677870603498
token: {'_repetition_id': 0, '_chocolate_id': 1}, params: {'x': -1.7702774043807228, 'y': 0.9182863001230128}, loss:111.11013186694515
Updating
Sampling
token: {'_repetition_id': 1, '_chocolate_id': 1}, params: {'x': -1.7702774043807228, 'y': 0.9182863001230128}, loss:111.11013186694515
token: {'_repetition_id': 0, '_chocolate_id': 2}, params: {'x': 4.478070847480515, 'y': 0.48253662844702294}, loss:96.16864084899738
token: {'_repetition_id': 1, '_chocolate_id': 2}, params: {'x': 4.478070847480515, 'y': 0.48253662844702294}, loss:96.16864084899738
Updating
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 3}, params: {'x': -3.346171228121967, 'y': -0.007153576861073319}, loss:107.07818942458216
token: {'_repetition_id': 1, '_chocolate_id': 3}, params: {'x': -3.346171228121967, 'y': -0.007153576861073319}, loss:107.07818942458216
Traceback (most recent call last):
  File "/home/andrews/automl/code.py", line 60, in <module>
    main(total_loops=3, parallel_samples=3)
  File "/home/andrews/automl/code.py", line 45, in main
    token, params = sampler.next()
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/base.py", line 156, in next
    token, params = self._next(reps_token)
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/sample/random.py", line 64, in _next
    i = self.conn.count_results()
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 59, in count_results
    return len(self.all_results())
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 48, in all_results
    losses[col] = [r[col] for r in result_group if r[col] is not None]
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 48, in <listcomp>
    losses[col] = [r[col] for r in result_group if r[col] is not None]
KeyError: '_loss'
make: *** [Makefile:33: run] Error 1

After fix:

Sampling
token: {'_repetition_id': 0, '_chocolate_id': 0}, params: {'x': 5.608149406742257, 'y': 3.064894367855281}, loss:617.0409012163334
token: {'_repetition_id': 1, '_chocolate_id': 0}, params: {'x': 5.608149406742257, 'y': 3.064894367855281}, loss:617.0409012163334
WARNING:root:No loss values found for any column in result group [{'x': 0.9673457838951881, 'y': 0.75540786398794, '_repetition_id': 0, '_chocolate_id': 0}, {'x': 0.9673457838951881, 'y': 0.75540786398794, '_repetition_id': 1, '_chocolate_id': 0}]
token: {'_repetition_id': 0, '_chocolate_id': 1}, params: {'x': 1.209453876737438, 'y': -4.687402668141968}, loss:464.17108499501876
Updating
Sampling
token: {'_repetition_id': 1, '_chocolate_id': 1}, params: {'x': 1.209453876737438, 'y': -4.687402668141968}, loss:464.17108499501876
token: {'_repetition_id': 0, '_chocolate_id': 2}, params: {'x': 4.850600876835443, 'y': 2.9368614513154974}, loss:281.10752812415893
token: {'_repetition_id': 1, '_chocolate_id': 2}, params: {'x': 4.850600876835443, 'y': 2.9368614513154974}, loss:281.10752812415893
Updating
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 3}, params: {'x': 2.7178749937884614, 'y': 4.973698204569493}, loss:420.2805540896294
token: {'_repetition_id': 1, '_chocolate_id': 3}, params: {'x': 2.7178749937884614, 'y': 4.973698204569493}, loss:420.2805540896294
WARNING:root:No loss values found for column _loss in result group [{'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 0, '_chocolate_id': 3}, {'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 1, '_chocolate_id': 3}]
WARNING:root:No loss values found for any column in result group [{'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 0, '_chocolate_id': 3}, {'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 1, '_chocolate_id': 3}]
token: {'_repetition_id': 0, '_chocolate_id': 4}, params: {'x': 1.9722154031738608, 'y': -1.9901511023561707}, loss:83.9580854012591
Updating