microsoft / jericho

A learning environment for man-made Interactive Fiction games.
GNU General Public License v2.0
253 stars 42 forks source link

Add parallelization in get_valid_actions #51

Closed MarcCote closed 2 years ago

MarcCote commented 2 years ago

Do not merge. I made this PR so we can discuss it. There's still some cleaning/refactoring that needs to be done.

MarcCote commented 2 years ago

I'm using the following script to benchmark:

import ctypes
import time
import sys
import jericho
import glob
from tqdm import tqdm

def run(rom, optimize=True, parallel=False, num_steps=25):
    """Time ``get_valid_actions`` while replaying the start of a walkthrough.

    Args:
        rom: Path to the game ROM, passed to ``jericho.FrotzEnv``.
        optimize: Forwarded as ``use_ctypes`` to ``get_valid_actions``.
        parallel: Forwarded as ``use_parallel`` to ``get_valid_actions``.
        num_steps: Number of walkthrough actions to replay (default 25).

    Returns:
        A ``(elapsed_seconds, valid_per_step)`` tuple. ``valid_per_step[i]`` is
        the result of ``get_valid_actions`` just before walkthrough action ``i``
        is executed. The elapsed time covers the reset and the replay loop.
    """
    env = jericho.FrotzEnv(rom)
    try:
        walkthrough = env.get_walkthrough()
        # perf_counter is monotonic and high-resolution, so system clock
        # adjustments cannot skew the benchmark the way time.time() can.
        start = time.perf_counter()
        env.reset()
        valid_per_step = []
        for act in tqdm(walkthrough[:num_steps]):
            valid_per_step.append(
                env.get_valid_actions(use_ctypes=optimize, use_parallel=parallel)
            )
            env.step(act)
        elapsed = time.perf_counter() - start
    finally:
        # Release the emulator even if a validity check or step raises.
        env.close()
    return elapsed, valid_per_step

def _print_action_diffs(reference, candidate, label):
    """Print every per-step action on which ``candidate`` disagrees with ``reference``.

    Args:
        reference: Per-step lists of valid actions from the pure-Python run,
            treated as ground truth.
        candidate: Per-step lists of valid actions from the variant under test.
        label: Name of the variant, used in the printed messages.
    """
    # zip() silently truncates, so surface any length mismatch explicitly.
    if len(reference) != len(candidate):
        print(f"{label} produced {len(candidate)} steps but Python produced {len(reference)}")
    for idx, (truth, pred) in enumerate(zip(reference, candidate)):
        for a in set(truth) - set(pred):
            print(f"Step {idx} Act {a} was valid in Python but not in {label}")
        for a in set(pred) - set(truth):
            print(f"Step {idx} Act {a} was valid in {label} but not in Python")


def compare(rom):
    """Benchmark the ``get_valid_actions`` variants on ``rom`` against pure Python.

    Runs the pure-Python baseline, then the ctypes, parallel, and
    ctypes+parallel variants, in that order. Prints any per-step differences
    from the baseline, then each variant's timing and its speedup over the
    baseline.
    """
    t_python, gt_valids = run(rom, optimize=False)
    # Dict insertion order fixes the run order and the order of the reports.
    results = {
        "CTypes": run(rom, optimize=True),
        "Para": run(rom, optimize=False, parallel=True),
        "ParaC": run(rom, optimize=True, parallel=True),
    }

    for label, (_, valids) in results.items():
        _print_action_diffs(gt_valids, valids, label)

    for label, (elapsed, _) in results.items():
        speedup = 100 * (t_python / elapsed - 1)
        print(f"Rom {rom} Python {t_python:.1f} {label} {elapsed:.1f} Speedup {speedup:.1f}%")

def _print_step_diffs(env, idx, reference, candidate, label):
    """Report where ``candidate`` disagrees with ``reference`` at walkthrough step ``idx``.

    For each false positive (valid per ``candidate`` but not per ``reference``),
    the action is also executed and the ``env.step`` result is printed. The
    emulator state is then restored, so the walkthrough replay is unaffected.
    """
    for a in set(reference) - set(candidate):
        print(f"Step {idx} Act {a} was valid in Python but not in {label}")
    for a in set(candidate) - set(reference):
        print(f"Step {idx} Act {a} was valid in {label} but not in Python")
        state = env.get_state()
        print(env.step(a))
        env.set_state(state)


def check_correctness(rom, num_steps=25):
    """Compare the valid-action variants step by step on one shared env.

    At each of the first ``num_steps`` walkthrough steps, the pure-Python
    result is compared with the ctypes and ctypes+parallel results.
    Disagreements are printed.

    Returns:
        Elapsed seconds for the reset and replay loop.
    """
    env = jericho.FrotzEnv(rom)
    try:
        walkthrough = env.get_walkthrough()
        start = time.perf_counter()
        env.reset()
        # enumerate *inside* tqdm so tqdm can read the sequence length and
        # show a total/ETA (tqdm(enumerate(...)) hides it).
        for idx, act in enumerate(tqdm(walkthrough[:num_steps])):
            valid_acts = env.get_valid_actions(use_ctypes=False)
            valid_acts_c = env.get_valid_actions(use_ctypes=True)
            valid_acts_para = env.get_valid_actions(use_ctypes=True, use_parallel=True)
            if valid_acts != valid_acts_c:
                _print_step_diffs(env, idx, valid_acts, valid_acts_c, "CTypes")
            if valid_acts != valid_acts_para:
                # ctypes + parallel is the configuration compare() labels "ParaC".
                _print_step_diffs(env, idx, valid_acts, valid_acts_para, "ParaC")
            env.step(act)
        elapsed = time.perf_counter() - start
    finally:
        # Release the emulator even if a check or step raises.
        env.close()
    return elapsed

if __name__ == "__main__":
    # argparse gives a usage message instead of an IndexError when the ROM
    # argument is missing, and exposes check_correctness without code edits.
    import argparse

    parser = argparse.ArgumentParser(
        description="Benchmark or cross-check jericho get_valid_actions variants."
    )
    parser.add_argument("rom", help="Path to the game ROM file.")
    parser.add_argument(
        "--check",
        action="store_true",
        help="Run check_correctness instead of the timing comparison.",
    )
    args = parser.parse_args()
    if args.check:
        check_correctness(args.rom)
    else:
        compare(args.rom)
mhauskn commented 2 years ago

Hey Marc, thanks for this PR — I think we should definitely include it after any necessary cleanup/refactoring. I've double-checked the benchmark script and am getting similar speedups on my personal machine.

MarcCote commented 2 years ago

Great. Let me clean it up and I'll ping you later.

MarcCote commented 2 years ago

Alright, I'm happy with the refactoring for now.