To isolate the relevant changes, I've split the code into two "backends": one for default LLC estimation and one for LLC estimation on TPUs.
To use TPUs, run `pip install devinterp[tpu]` to install the additional optional dependencies. Then, importing from the original location (e.g., `devinterp.slt.sampler`) automatically gives you the correct functions/callbacks for your backend. You can override this by setting the environment variable `USE_TPU_BACKEND=0`.
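For reviewers, here's a minimal sketch of the kind of backend selection described above. It is not the code in this PR: the helper name `use_tpu_backend` is hypothetical, and it assumes that an explicit `USE_TPU_BACKEND` value always wins, with a fallback to checking whether PyTorch/XLA (`torch_xla`) is importable.

```python
import importlib.util
import os


def use_tpu_backend() -> bool:
    """Hypothetical helper (not devinterp's API): decide whether to use the TPU backend.

    Assumption: an explicit USE_TPU_BACKEND value overrides auto-detection;
    otherwise we check whether torch_xla (PyTorch/XLA) is installed.
    """
    flag = os.environ.get("USE_TPU_BACKEND")
    if flag is not None:
        # Treat "0", "false", "no" and the empty string as "off"; anything else as "on".
        return flag.strip().lower() not in {"0", "false", "no", ""}
    # No override set: fall back to detecting the optional TPU dependency.
    return importlib.util.find_spec("torch_xla") is not None
```

Since the dispatch happens when you import from `devinterp.slt.sampler`, it's safest to set the variable before that import, e.g. `USE_TPU_BACKEND=0 python your_script.py` from a shell (`your_script.py` is a placeholder).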
Added by Stan: I've also refactored `temperature` to `nbeta`, as this was needed for TPU integration with aether.
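For context on the rename, here's a minimal sketch of what `nbeta` stands for, assuming the standard SGLD-based LLC estimator from the SLT literature, λ̂ = nβ·(E[L_n(w)] − L_n(w*)), with the usual inverse temperature β* = 1/log n. In that formulation n and β only ever appear as the product nβ, so a single `nbeta` parameter carries the same information. The function names below are hypothetical and not part of devinterp's API.

```python
import math
from typing import Sequence


def default_nbeta(n: int) -> float:
    """Hypothetical helper: n * beta* with the standard choice beta* = 1 / log(n)."""
    if n < 2:
        raise ValueError("n must be at least 2 so that log(n) > 0")
    return n / math.log(n)


def llc_estimate(sampled_losses: Sequence[float], init_loss: float, nbeta: float) -> float:
    """Hypothetical helper: lambda_hat = nbeta * (mean sampled loss - loss at w*)."""
    if not sampled_losses:
        raise ValueError("need at least one sampled loss")
    mean_loss = sum(sampled_losses) / len(sampled_losses)
    return nbeta * (mean_loss - init_loss)
```

For example, `llc_estimate([1.0, 1.2], 1.0, 10.0)` gives approximately 1.0: the sampled losses average 1.1, which is 0.1 above the loss at w*, scaled by nβ = 10.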