facebookresearch / CompilerGym

Reinforcement learning environments for compiler and program optimization tasks
https://compilergym.ai/
MIT License

[Bug] NotImplementedError when using env.fork() with SynchronousSqliteLogger #744

Status: Open. ricardoprins opened this issue 2 years ago.

ricardoprins commented 2 years ago

The code below was executed in a Google Colab notebook:

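from pathlib import Path
from time import time

import compiler_gym  # noqa: F401  (registers the CompilerGym environments with gym)
import gym
from compiler_gym.wrappers.sqlite_logger import SynchronousSqliteLogger

db_path = Path("/content/db/test.db")
db_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the database directory exists

env = SynchronousSqliteLogger(
    env=gym.make('llvm-ir-ic-v0'),
    db_path=db_path,
)
env.reset()  # start an episode before stepping

def greedy(env, search_time_seconds: int, **kwargs) -> None:
    """Greedy search: repeatedly take the action with the highest one-step reward."""

    def eval_action(env, action: int):
        # Try the action on a fork so that the main environment is not modified.
        with env.fork() as fkd:
            return (fkd.step(action)[1], action)

    end_time = time() + search_time_seconds
    while time() < end_time:
        best = max(eval_action(env, action) for action in range(env.action_space.n))
        # Stop when no action gives a positive reward, or when the episode is done.
        if best[0] <= 0 or env.step(best[1])[2]:
            return

greedy(env, 300)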

And here's the traceback:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-3-5160ef3145bf> in <module>()
     21             return
     22 
---> 23 greedy(env, 300)

<ipython-input-3-5160ef3145bf> in greedy(env, search_time_seconds, **kwargs)
     17     end_time = time() + search_time_seconds
     18     while time() < end_time:
---> 19         best = max(eval_action(env, action) for action in range(env.action_space.n))
     20         if best[0] <= 0 or env.step(best[1])[2]:
     21             return

<ipython-input-3-5160ef3145bf> in <genexpr>(.0)
     17     end_time = time() + search_time_seconds
     18     while time() < end_time:
---> 19         best = max(eval_action(env, action) for action in range(env.action_space.n))
     20         if best[0] <= 0 or env.step(best[1])[2]:
     21             return

<ipython-input-3-5160ef3145bf> in eval_action(env, action)
     12 
     13     def eval_action(env, action: int):
---> 14         with env.fork() as fkd:
     15             return (fkd.step(action)[1], action)
     16 

/usr/local/lib/python3.7/dist-packages/compiler_gym/wrappers/sqlite_logger.py in fork(self)
    272 
    273     def fork(self):
--> 274         raise NotImplementedError

NotImplementedError:

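The greedy search above depends on fork() returning an independent copy of the environment: each candidate action is tried on a fork, and only the best one is applied to the main environment. The sketch below shows that pattern on a hypothetical `ToyEnv` (not a CompilerGym class) whose state is a single "instruction count". It uses a fixed step budget instead of a time limit so that it is deterministic. Unlike real CompilerGym forks, which hold resources and should be closed (the original code does this with `with`), the toy fork is a plain deep copy that needs no cleanup.

```python
import copy
from typing import List, Tuple


class ToyEnv:
    """Hypothetical stand-in for a CompilerGym environment.

    The state is an integer "instruction count". Action `a` removes up to
    `deltas[a]` instructions, and the reward is the number actually removed.
    """

    def __init__(self, count: int, deltas: List[int]):
        self.count = count
        self.deltas = deltas
        self.actions: List[int] = []

    @property
    def n_actions(self) -> int:
        return len(self.deltas)

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        removed = min(self.deltas[action], self.count)
        self.count -= removed
        self.actions.append(action)
        done = self.count == 0
        return self.count, float(removed), done, {}

    def fork(self) -> "ToyEnv":
        # Independent deep copy: stepping the fork leaves `self` untouched.
        return copy.deepcopy(self)


def greedy(env: ToyEnv, max_steps: int) -> None:
    for _ in range(max_steps):
        # Evaluate every action on a throwaway fork of the current state.
        best_reward, best_action = max(
            (env.fork().step(a)[1], a) for a in range(env.n_actions)
        )
        if best_reward <= 0:
            return
        _, _, done, _ = env.step(best_action)
        if done:
            return
```

For example, `ToyEnv(10, [1, 4, 0])` is reduced to zero instructions by choosing action 1 three times, while the forks used for lookahead never change the main environment's state.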
ChrisCummins commented 2 years ago

Thanks for the report @ricardoprins. I can see that the problem is that the SqliteLogger wrapper does not yet support fork():

https://github.com/facebookresearch/CompilerGym/blob/development/compiler_gym/wrappers/sqlite_logger.py#L273-L274

Would you like to have a go at implementing it?

Cheers, Chris

ricardoprins commented 2 years ago

Sure, I'll try!

ricardoprins commented 2 years ago

Would simply doing this work, or are there any problems I'm unaware of?

The fork() method is already implemented in the CompilerEnvWrapper class, which the logger inherits from.

    def fork(self):
        return super().fork()
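Whether the inherited method is enough depends on how `CompilerEnvWrapper.fork()` rebuilds the wrapper, which this thread does not show. The sketch below assumes, hypothetically, that the base class recreates `type(self)` around the forked inner environment, passing only the `env` argument it knows about. Under that assumption, a subclass whose constructor also requires `db_path` fails with a `TypeError`. All class names here are illustrative stand-ins, not CompilerGym classes.

```python
class BaseWrapper:
    """Hypothetical wrapper base class (not CompilerGym's actual code)."""

    def __init__(self, env):
        self.env = env

    def fork(self):
        # Generic fork: rebuild the same wrapper type around a forked inner
        # env, passing only the constructor argument the base class knows.
        return type(self)(env=self.env.fork())


class LoggerWrapper(BaseWrapper):
    """Hypothetical logger whose constructor needs an extra argument."""

    def __init__(self, env, db_path: str):
        super().__init__(env)
        self.db_path = db_path

    def fork(self):
        # The override proposed above: defer entirely to the base class.
        return super().fork()


class InnerEnv:
    """Hypothetical inner environment with a trivial fork()."""

    def fork(self):
        return InnerEnv()


def try_inherited_fork() -> str:
    logger = LoggerWrapper(env=InnerEnv(), db_path="test.db")
    try:
        logger.fork()
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"
```

If the base class instead returned `self.env.fork()` unchanged, no error would occur, but the returned copy would be a bare inner environment and its steps would not be logged at all.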

ChrisCummins commented 2 years ago

> Would simply doing this work, or are there any problems I'm unaware of?
>
> The fork() method is already implemented in the CompilerEnvWrapper class, which the logger inherits from.
>
>     def fork(self):
>         return super().fork()

Not quite. You would need to construct a new SqliteLogger wrapper around the forked environment and pass in the right constructor arguments, so that the forked environment writes to the same database.

Cheers, Chris
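A minimal sketch of the approach described above, using hypothetical stand-in classes (`InnerEnv`, `SqliteStepLogger`) rather than CompilerGym's real `SynchronousSqliteLogger`. The key line is in `fork()`: the forked inner environment is wrapped in a *new* logger that reuses the original `db_path`, so steps taken on the fork land in the same database. The real wrapper may take further constructor arguments, which a real `fork()` would also need to forward. Context-manager support is included so the fork can be used with `with env.fork() as fkd:` as in the original report.

```python
import sqlite3
from typing import List, Optional


class InnerEnv:
    """Hypothetical stand-in for the wrapped CompilerGym environment."""

    def __init__(self, actions: Optional[List[int]] = None):
        self.actions = list(actions or [])

    def step(self, action: int):
        self.actions.append(action)
        return None, 0.0, False, {}

    def fork(self) -> "InnerEnv":
        # Copy the action history so the fork starts from the same state.
        return InnerEnv(self.actions)

    def close(self) -> None:
        pass


class SqliteStepLogger:
    """Hypothetical logger wrapper that records every step in SQLite."""

    def __init__(self, env: InnerEnv, db_path: str):
        self.env = env
        self.db_path = db_path
        self.connection = sqlite3.connect(db_path)
        self.connection.execute(
            "CREATE TABLE IF NOT EXISTS steps (action INTEGER, reward REAL)"
        )
        self.connection.commit()

    def step(self, action: int):
        observation, reward, done, info = self.env.step(action)
        self.connection.execute(
            "INSERT INTO steps (action, reward) VALUES (?, ?)", (action, reward)
        )
        self.connection.commit()
        return observation, reward, done, info

    def fork(self) -> "SqliteStepLogger":
        # Wrap the forked inner env in a new logger, reusing the same
        # constructor arguments so that the fork writes to the same database.
        return SqliteStepLogger(env=self.env.fork(), db_path=self.db_path)

    def close(self) -> None:
        self.connection.close()
        self.env.close()

    def __enter__(self) -> "SqliteStepLogger":
        return self

    def __exit__(self, *exc) -> None:
        self.close()
```

In this sketch each logger opens its own SQLite connection and commits after every step, which keeps the example simple; a logger that buffers steps in memory would also need to flush the fork's buffer when the fork is closed.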