JasonGross / guarantees-based-mechanistic-interpretability


Speed up `ein.array`, add `use_cache=False` #147

Closed JasonGross closed 3 months ago

JasonGross commented 3 months ago

`ein.array` now accepts `use_cache=False` to skip the slow `lambda_hash` call in cases where we don't need caching.
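A rough usage sketch (the lambda, the `size=[...]` form, and the top-level `import ein` are illustrative assumptions, not copied from the diff):

```python
import ein  # assumed import; adjust to however ein is pulled into gbmi

# One-off array where memoizing on lambda_hash buys nothing:
# use_cache=False skips hashing the lambda entirely.
squares = ein.array(lambda i: i * i, size=[10], use_cache=False)
```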

Additionally, instead of eagerly applying the function twice to determine the target device, we now apply it only once in the common case where either the default device is correct or `device=` is passed explicitly, for a speedup of up to 2x. When the device guess is wrong, we catch the resulting `RuntimeError`, check its message for "same device" to recognize the failure, correct the device, and emit a warning.
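Roughly, the device handling has the shape below (a simplified 1-D sketch, not the code in this PR; `_array_1d` and its plain-index fallback probe are made up for illustration):

```python
import warnings
import torch

def _array_1d(fn, size, device=None):
    """Simplified 1-D sketch: evaluate fn once on the requested (or default)
    device, and only fall back to a probing second evaluation when that fails
    with a cross-device RuntimeError."""
    target = device if device is not None else torch.empty(0).device
    idx = torch.arange(size, device=target)
    try:
        return fn(idx)  # fast path: a single evaluation when the device guess is right
    except RuntimeError as e:
        # PyTorch cross-device errors mention "same device", e.g.
        # "Expected all tensors to be on the same device, but found ...".
        if "same device" not in str(e):
            raise
        # Slow path: probe fn with a plain Python index so no tensors mix,
        # read the device off the result, warn, and re-run there.
        corrected = torch.as_tensor(fn(0)).device
        warnings.warn(
            f"default device {target} did not work for this lambda; retrying on "
            f"{corrected}. Pass device= explicitly to skip the extra evaluation."
        )
        return fn(torch.arange(size, device=corrected))
```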

cc @euanong @LouisYRYJ

euanong commented 3 months ago

`ein` is now being developed in eintorch (private repo, but installable via `pip install eintorch`) -- shall I try to port the gbmi codebase to use eintorch instead?

eintorch no longer uses `lambda_hash` (I implemented a fix using context managers).
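For context, one generic way a context manager can replace hashing the lambda is to scope (or disable) the cache explicitly, keyed by a caller-supplied name; this is purely an illustrative sketch, not eintorch's actual implementation:

```python
from contextlib import contextmanager

_CACHE_ENABLED = False
_CACHE: dict = {}

@contextmanager
def cached_arrays():
    """Turn memoization on only inside this block; entries are keyed by an
    explicit caller-supplied name, so the lambda itself is never hashed."""
    global _CACHE_ENABLED
    prev = _CACHE_ENABLED
    _CACHE_ENABLED = True
    try:
        yield
    finally:
        _CACHE_ENABLED = prev
        _CACHE.clear()

def array(fn, size, *, key=None):
    # Stand-in for the real tensor construction; caching only applies when
    # we are inside cached_arrays() and the caller provided a key.
    if _CACHE_ENABLED and key is not None and key in _CACHE:
        return _CACHE[key]
    result = [fn(i) for i in range(size)]
    if _CACHE_ENABLED and key is not None:
        _CACHE[key] = result
    return result
```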

JasonGross commented 3 months ago

> shall I try to port gbmi codebase to use eintorch instead?

That'd be useful if it's not too much work, but not super high priority

JasonGross commented 3 months ago

@euanong If you invite me to the repo, I can make a PR that implements the same speedup I implemented here, where passing `size=` explicitly avoids running the function twice unless the default device does not work (though I haven't measured how much of a speedup it is; maybe it's not enough to be worth it).