timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Add Basic Support for TPUs, refactor temp -> nbeta #81

Closed: jqhoogland closed this pull request 3 months ago

jqhoogland commented 5 months ago

To highlight just the relevant parts, I've split out two backends: one for default LLC (local learning coefficient) estimation and one for LLC estimation on TPUs.

To use TPUs, run `pip install devinterp[tpu]` to install the additional optional dependencies. You can then keep importing from the original locations (e.g., `devinterp.slt.sampler`), and you'll automatically get the correct functions and callbacks for your backend. To override this selection and use the default backend, set the environment variable `USE_TPU_BACKEND=0`.
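The automatic backend selection with an environment-variable opt-out could look roughly like the following. This is a minimal sketch, not devinterp's actual code. The helper names `select_backend` and `tpu_runtime_available` are hypothetical. Treating "TPU available" as "`torch_xla` is importable" is an assumption about what the `tpu` extra installs.

```python
import importlib.util
import os
from typing import Mapping, Optional


def tpu_runtime_available() -> bool:
    # Assumption: the TPU backend is usable iff torch_xla (PyTorch/XLA) is installed.
    return importlib.util.find_spec("torch_xla") is not None


def select_backend(env: Optional[Mapping[str, str]] = None) -> str:
    """Return "tpu" or "default" (hypothetical helper, not devinterp's API)."""
    env = os.environ if env is None else env
    # USE_TPU_BACKEND=0 is an explicit opt-out and always wins.
    if env.get("USE_TPU_BACKEND", "").strip() == "0":
        return "default"
    # Otherwise pick the TPU backend only when its optional dependencies are present.
    return "tpu" if tpu_runtime_available() else "default"


if __name__ == "__main__":
    print(select_backend())
```

Passing `env` explicitly keeps the rule easy to test without mutating `os.environ`.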

├── devinterp
│   ├── backends
│   │   ├── default
│   │   │   └── slt
│   │   │       └── sampler.py
│   │   └── tpu
│   │       └── slt
│   │           └── sampler.py
│   ├── slt
│   │   ├── callback.py
│   │   ├── cov.py
│   │   ├── gradient.py
│   │   ├── llc.py
│   │   ├── loss.py
│   │   ├── mala.py
│   │   ├── norms.py
│   │   ├── sampler.py
│   │   ├── trace.py
│   │   └── wbic.py
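With this layout, a module under `devinterp/slt/` that has backend-specific variants (here `sampler.py`) can act as a thin shim over `devinterp/backends/<backend>/slt/`. One way to write such a shim is sketched below. `backend_module_path`, `reexport`, and the `BACKENDS` tuple are hypothetical helpers for illustration, not necessarily how this PR wires the re-export.

```python
import importlib
from types import ModuleType
from typing import MutableMapping

BACKENDS = ("default", "tpu")


def backend_module_path(relative_name: str, backend: str, package: str = "devinterp") -> str:
    # e.g. ("slt.sampler", "tpu") -> "devinterp.backends.tpu.slt.sampler"
    if backend not in BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {BACKENDS}")
    return f"{package}.backends.{backend}.{relative_name}"


def reexport(module: ModuleType, namespace: MutableMapping[str, object]) -> None:
    # Copy the backend module's public names into the shim's namespace so that
    # `from devinterp.slt.sampler import ...` keeps working unchanged.
    names = getattr(module, "__all__", None)
    if names is None:
        names = [name for name in vars(module) if not name.startswith("_")]
    for name in names:
        namespace[name] = getattr(module, name)


# Hypothetical contents of devinterp/slt/sampler.py:
#   _backend = "default"  # or "tpu"; the selection rule is out of scope here
#   _impl = importlib.import_module(backend_module_path("slt.sampler", _backend))
#   reexport(_impl, globals())
```

Honoring `__all__` when the backend module defines it keeps private helpers from leaking into the public import path.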

Added by Stan: I've also refactored `temperature` to `nbeta`, as this was needed for TPU integration with aether.
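For context, nβ (the number of training samples n times the inverse temperature β) is the factor that multiplies the loss gap in the standard LLC estimator, λ̂ = nβ (E_β[L_n(w)] − L_n(w*)). Here the expectation is over weights sampled near the initial point w*. The sketch below illustrates only that relationship. `estimate_llc` and `default_nbeta` are hypothetical helpers, not devinterp's implementation. Using β = 1/log n as the default follows Watanabe's WBIC choice and is an assumption, not something stated in this PR.

```python
import math
from statistics import fmean
from typing import Sequence


def default_nbeta(n: int) -> float:
    # nβ with β = 1 / log(n), the inverse temperature used by WBIC (assumption).
    if n < 2:
        raise ValueError("n must be at least 2 so that log(n) > 0")
    return n / math.log(n)


def estimate_llc(sampled_losses: Sequence[float], init_loss: float, nbeta: float) -> float:
    # λ̂ = nβ · (mean loss over sampled weights − loss at the initial point w*)
    if len(sampled_losses) == 0:
        raise ValueError("need at least one sampled loss")
    return nbeta * (fmean(sampled_losses) - init_loss)


# Usage: losses from a sampling chain started at w*, for a training set of n = 10_000 samples.
nbeta = default_nbeta(10_000)
llc = estimate_llc([0.52, 0.55, 0.51], init_loss=0.50, nbeta=nbeta)
```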