Open andrewssobral opened 4 months ago
This fixes the error that occurs when using cross-validation with MongoDB. It works fine with SQLite, but fails with MongoDB. Not tested with other databases.
Code to reproduce the error:
```python
import multiprocessing

import numpy as np
import chocolate as choco


# Himmelblau's function
def objective_function(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2


def score_himmelblau(params):
    x, y = params['x'], params['y']
    # Chocolate minimizes the loss
    return objective_function(x, y)


def worker(params, token, result_queue):
    loss = score_himmelblau(params)
    print(f"token: {token}, params: {params}, loss:{loss}")
    result_queue.put([token, loss])


def main(total_loops=2, parallel_samples=2):
    # conn_str = "sqlite:///himmelblau.db"
    # conn = choco.SQLiteConnection(url=conn_str)
    conn_str = "mongodb://user:pass@server:port"
    conn = choco.MongoDBConnection(conn_str)
    conn.clear()

    s = {"x": choco.uniform(-6, 6),
         "y": choco.uniform(-6, 6)}

    cross_validation = 2  # Number of repetitions
    cv = None
    if cross_validation > 1:
        cv = choco.Repeat(repetitions=cross_validation, reduce=np.mean, rep_col="_repetition_id")

    sampler = choco.Random(conn, s, crossvalidation=cv)

    for _ in range(total_loops):
        processes = []
        result_queue = multiprocessing.Queue()
        print("Sampling")
        for _ in range(parallel_samples):
            token, params = sampler.next()
            p = multiprocessing.Process(target=worker, args=(params, token, result_queue))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        print("Updating")
        while not result_queue.empty():
            [token, loss] = result_queue.get()
            sampler.update(token, loss)


if __name__ == "__main__":
    main(total_loops=3, parallel_samples=3)
```
Outputs:
```
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 0}, params: {'x': 5.052463796795994, 'y': -3.7657521620123244}, loss:265.4677870603498
token: {'_repetition_id': 1, '_chocolate_id': 0}, params: {'x': 5.052463796795994, 'y': -3.7657521620123244}, loss:265.4677870603498
token: {'_repetition_id': 0, '_chocolate_id': 1}, params: {'x': -1.7702774043807228, 'y': 0.9182863001230128}, loss:111.11013186694515
Updating
Sampling
token: {'_repetition_id': 1, '_chocolate_id': 1}, params: {'x': -1.7702774043807228, 'y': 0.9182863001230128}, loss:111.11013186694515
token: {'_repetition_id': 0, '_chocolate_id': 2}, params: {'x': 4.478070847480515, 'y': 0.48253662844702294}, loss:96.16864084899738
token: {'_repetition_id': 1, '_chocolate_id': 2}, params: {'x': 4.478070847480515, 'y': 0.48253662844702294}, loss:96.16864084899738
Updating
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 3}, params: {'x': -3.346171228121967, 'y': -0.007153576861073319}, loss:107.07818942458216
token: {'_repetition_id': 1, '_chocolate_id': 3}, params: {'x': -3.346171228121967, 'y': -0.007153576861073319}, loss:107.07818942458216
Traceback (most recent call last):
  File "/home/andrews/automl/code.py", line 60, in <module>
    main(total_loops=3, parallel_samples=3)
  File "/home/andrews/automl/code.py", line 45, in main
    token, params = sampler.next()
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/base.py", line 156, in next
    token, params = self._next(reps_token)
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/sample/random.py", line 64, in _next
    i = self.conn.count_results()
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 59, in count_results
    return len(self.all_results())
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 48, in all_results
    losses[col] = [r[col] for r in result_group if r[col] is not None]
  File "/home/andrews/automl/env/lib/python3.9/site-packages/chocolate/crossvalidation/repeat.py", line 48, in <listcomp>
    losses[col] = [r[col] for r in result_group if r[col] is not None]
KeyError: '_loss'
make: *** [Makefile:33: run] Error 1
```
After fix:
```
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 0}, params: {'x': 5.608149406742257, 'y': 3.064894367855281}, loss:617.0409012163334
token: {'_repetition_id': 1, '_chocolate_id': 0}, params: {'x': 5.608149406742257, 'y': 3.064894367855281}, loss:617.0409012163334
WARNING:root:No loss values found for any column in result group [{'x': 0.9673457838951881, 'y': 0.75540786398794, '_repetition_id': 0, '_chocolate_id': 0}, {'x': 0.9673457838951881, 'y': 0.75540786398794, '_repetition_id': 1, '_chocolate_id': 0}]
token: {'_repetition_id': 0, '_chocolate_id': 1}, params: {'x': 1.209453876737438, 'y': -4.687402668141968}, loss:464.17108499501876
Updating
Sampling
token: {'_repetition_id': 1, '_chocolate_id': 1}, params: {'x': 1.209453876737438, 'y': -4.687402668141968}, loss:464.17108499501876
token: {'_repetition_id': 0, '_chocolate_id': 2}, params: {'x': 4.850600876835443, 'y': 2.9368614513154974}, loss:281.10752812415893
token: {'_repetition_id': 1, '_chocolate_id': 2}, params: {'x': 4.850600876835443, 'y': 2.9368614513154974}, loss:281.10752812415893
Updating
Sampling
token: {'_repetition_id': 0, '_chocolate_id': 3}, params: {'x': 2.7178749937884614, 'y': 4.973698204569493}, loss:420.2805540896294
token: {'_repetition_id': 1, '_chocolate_id': 3}, params: {'x': 2.7178749937884614, 'y': 4.973698204569493}, loss:420.2805540896294
WARNING:root:No loss values found for column _loss in result group [{'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 0, '_chocolate_id': 3}, {'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 1, '_chocolate_id': 3}]
WARNING:root:No loss values found for any column in result group [{'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 0, '_chocolate_id': 3}, {'x': 0.7264895828157051, 'y': 0.914474850380791, '_repetition_id': 1, '_chocolate_id': 3}]
token: {'_repetition_id': 0, '_chocolate_id': 4}, params: {'x': 1.9722154031738608, 'y': -1.9901511023561707}, loss:83.9580854012591
Updating
```
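The failing line in `repeat.py` indexes `r[col]` directly and then tests `is not None`, which works when a pending trial's loss is stored as `None` (as a NULL column in a SQLite row would be). A MongoDB document for a trial that has not been updated yet presumably has no `_loss` field at all, so `r[col]` raises `KeyError: '_loss'`. Below is a minimal sketch of a tolerant version, assuming the fix reads loss fields with `dict.get` (which returns `None` for a missing key). `reduce_repetitions` and the sample records are hypothetical, loosely modelled on `Repeat.all_results`; this is not the PR's actual diff, and the warning messages only mirror the ones in the output above.

```python
import logging
from itertools import groupby
from statistics import mean


def reduce_repetitions(results, loss_cols=("_loss",), reduce=mean):
    """Merge the repetitions of each _chocolate_id into one result (hypothetical helper).

    A loss field that is absent (MongoDB-style document) is treated like one
    that is None (SQLite-style row), so pending trials no longer raise KeyError.
    """
    reduced = []
    ordered = sorted(results, key=lambda r: r["_chocolate_id"])
    for _, group in groupby(ordered, key=lambda r: r["_chocolate_id"]):
        group = list(group)
        # Keep the parameters and id of the first repetition, drop per-run fields.
        entry = {k: v for k, v in group[0].items()
                 if k != "_repetition_id" and k not in loss_cols}
        found_any = False
        for col in loss_cols:
            # r.get(col) returns None for a missing key instead of raising KeyError.
            losses = [r.get(col) for r in group if r.get(col) is not None]
            if losses:
                entry[col] = reduce(losses)
                found_any = True
            else:
                # Still pending: keep the entry so it is still counted.
                entry[col] = None
                logging.warning("No loss values found for column %s in result group %s", col, group)
        if not found_any:
            logging.warning("No loss values found for any column in result group %s", group)
        reduced.append(entry)
    return reduced


results = [
    {"_chocolate_id": 0, "_repetition_id": 0, "x": 1.0, "_loss": 1.0},
    {"_chocolate_id": 0, "_repetition_id": 1, "x": 1.0, "_loss": 3.0},
    # Pending trial stored MongoDB-style: no "_loss" key at all,
    # so results[2]["_loss"] would raise KeyError: '_loss'.
    {"_chocolate_id": 1, "_repetition_id": 0, "x": 2.0},
    {"_chocolate_id": 1, "_repetition_id": 1, "x": 2.0},
]
print(reduce_repetitions(results))
# [{'_chocolate_id': 0, 'x': 1.0, '_loss': 2.0}, {'_chocolate_id': 1, 'x': 2.0, '_loss': None}]
```

Keeping pending groups in the result with a `None` loss is a design choice of this sketch. It assumes that `count_results` (which returns `len(self.all_results())`) should still count trials that were sampled but not yet updated, which matches the sampler moving on to `_chocolate_id` 4 in the output above. The actual fix may handle this differently.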