Closed: MarcCote closed this pull request 2 years ago.
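I'm using the following script to benchmark the pure-Python, ctypes, and parallel implementations of `get_valid_actions` (and to check that they return the same actions):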
import ctypes
import time
import sys
import jericho
import glob
from tqdm import tqdm
def run(rom, optimize=True, parallel=False, max_steps=25):
    """Time ``get_valid_actions`` along the first steps of a game's walkthrough.

    Loads ``rom`` with ``jericho.FrotzEnv`` and resets it. Then, for each of
    the first ``max_steps`` walkthrough actions, it queries the valid actions
    and steps the game with the walkthrough action.

    Args:
        rom: Path of the game file passed to ``jericho.FrotzEnv``.
        optimize: Forwarded to ``get_valid_actions`` as ``use_ctypes``.
        parallel: Forwarded to ``get_valid_actions`` as ``use_parallel``.
        max_steps: Number of walkthrough actions to replay (default 25).

    Returns:
        A ``(elapsed, valid_per_step)`` tuple:
        - ``elapsed`` is the number of seconds from the reset through the
          last step. Environment creation and shutdown are not timed.
        - ``valid_per_step`` holds one ``get_valid_actions`` result per
          replayed step.
    """
    env = jericho.FrotzEnv(rom)
    try:
        walkthrough = env.get_walkthrough()
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # so it is the appropriate clock for measuring durations.
        start = time.perf_counter()
        env.reset()
        valid_per_step = []
        for act in tqdm(walkthrough[:max_steps]):
            valid_per_step.append(
                env.get_valid_actions(use_ctypes=optimize, use_parallel=parallel)
            )
            env.step(act)
        elapsed = time.perf_counter() - start
    finally:
        # Release the emulator even if a query or step raises.
        env.close()
    return elapsed, valid_per_step
def _print_action_diffs(expected_steps, actual_steps, label, extra_label=None):
    """Print, step by step, the actions on which two result lists disagree.

    Args:
        expected_steps: Per-step valid actions from the pure-Python path.
        actual_steps: Per-step valid actions from the implementation under test.
        label: Name used in "valid in Python but not in <label>" messages.
        extra_label: Name used in "valid in <extra_label> but not in Python"
            messages. Defaults to ``label``.
    """
    if extra_label is None:
        extra_label = label
    for expected, actual in zip(expected_steps, actual_steps):
        expected, actual = set(expected), set(actual)
        # sorted() makes the report order reproducible across runs
        # (set iteration order of strings varies with hash randomization).
        for a in sorted(expected - actual):
            print(f"Act {a} was valid in Python but not in {label}")
        for a in sorted(actual - expected):
            print(f"Act {a} was valid in {extra_label} but not in Python")


def compare(rom):
    """Benchmark and cross-check four ``get_valid_actions`` configurations.

    Runs ``run`` on ``rom`` four times:
    - pure Python, which is the reference;
    - ctypes;
    - parallel without ctypes;
    - parallel with ctypes.

    It then prints every action on which each optimized configuration
    disagrees with the reference, followed by each configuration's runtime and
    its speedup over pure Python.
    """
    t1, gt_valids = run(rom, optimize=False)
    t2, pred_valids = run(rom, optimize=True)
    t3, pred_valids_para = run(rom, optimize=False, parallel=True)
    t4, pred_valids_para_c = run(rom, optimize=True, parallel=True)

    # Every run replays the same walkthrough prefix, so the result lists must
    # have equal length; otherwise zip() would silently drop trailing steps.
    assert len(gt_valids) == len(pred_valids)
    assert len(gt_valids) == len(pred_valids_para)
    assert len(gt_valids) == len(pred_valids_para_c)

    _print_action_diffs(gt_valids, pred_valids, "CTypes", "Ctypes")
    _print_action_diffs(gt_valids, pred_valids_para, "Para")
    _print_action_diffs(gt_valids, pred_valids_para_c, "ParaC")

    for name, elapsed in (("Ctypes", t2), ("Para", t3), ("ParaC", t4)):
        # Percentage speedup relative to the pure-Python run.
        speedup = 100 * (t1 / elapsed - 1)
        print(f"Rom {rom} Python {t1:.1f} {name} {elapsed:.1f} Speedup {speedup:.1f}%")
def _probe_action_diffs(env, idx, expected, actual, label, extra_label=None):
    """Report one step's valid-action disagreements and probe the extra actions.

    Prints each action in ``expected`` (the pure-Python result) that is missing
    from ``actual``. Then, for each action only ``actual`` reports, it:
    - prints the action;
    - saves the emulator state with ``env.get_state``;
    - prints the result of ``env.step`` on that action;
    - restores the saved state.

    Args:
        env: The environment being replayed.
        idx: Index of the current walkthrough step, used in messages.
        expected: Valid actions from the pure-Python path.
        actual: Valid actions from the implementation under test.
        label: Name used in "valid in Python but not in <label>" messages.
        extra_label: Name used in "valid in <extra_label> but not in Python"
            messages. Defaults to ``label``.
    """
    if extra_label is None:
        extra_label = label
    expected, actual = set(expected), set(actual)
    for a in sorted(expected - actual):
        print(f"Step {idx} Act {a} was valid in Python but not in {label}")
    for a in sorted(actual - expected):
        print(f"Step {idx} Act {a} was valid in {extra_label} but not in Python")
        # Show what the suspicious action actually does, then rewind so the
        # walkthrough replay continues from the saved state.
        state = env.get_state()
        print(env.step(a))
        env.set_state(state)


def check_correctness(rom, max_steps=25):
    """Replay a walkthrough and compare ``get_valid_actions`` implementations.

    At each of the first ``max_steps`` walkthrough steps, it queries valid
    actions three ways:
    - pure Python;
    - ctypes;
    - ctypes with parallelism.

    It reports any set difference against the pure-Python result and then
    advances with the walkthrough action.

    Args:
        rom: Path of the game file passed to ``jericho.FrotzEnv``.
        max_steps: Number of walkthrough actions to replay (default 25).

    Returns:
        Seconds elapsed from the reset through the last step.
    """
    env = jericho.FrotzEnv(rom)
    try:
        walkthrough = env.get_walkthrough()
        start = time.perf_counter()
        env.reset()
        # Wrap the list (not the enumerate object) so tqdm knows the total.
        for idx, act in enumerate(tqdm(walkthrough[:max_steps])):
            valid_acts = env.get_valid_actions(use_ctypes=False)
            valid_acts_c = env.get_valid_actions(use_ctypes=True)
            valid_acts_para = env.get_valid_actions(use_ctypes=True, use_parallel=True)
            # Compare as sets: an ordering-only difference is not a mismatch.
            _probe_action_diffs(env, idx, valid_acts, valid_acts_c, "CTypes", "Ctypes")
            _probe_action_diffs(env, idx, valid_acts, valid_acts_para, "Para")
            env.step(act)
        elapsed = time.perf_counter() - start
    finally:
        # Release the emulator even if a query or step raises.
        env.close()
    return elapsed
if __name__ == "__main__":
    # Local import: only needed when the file is run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Benchmark or cross-check Jericho get_valid_actions implementations."
    )
    parser.add_argument(
        "roms", nargs="+", help="Game file(s) to load with jericho.FrotzEnv."
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Run check_correctness instead of the compare benchmark.",
    )
    args = parser.parse_args()
    # A single ROM argument with no flag behaves like the original script:
    # it runs compare() on that ROM.
    for rom in args.roms:
        if args.check:
            check_correctness(rom)
        else:
            compare(rom)
Hey Marc, thanks for this PR — I think we should definitely include it after any necessary cleanup/refactoring. I've double-checked the benchmark script and am getting similar speedups on my personal machine.
Great. Let me clean it up and I'll ping you later.
Alright, I'm happy with the refactoring for now.
Do not merge. I made this PR so we can discuss it. There's still some cleaning/refactoring that needs to be done.