timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Add cuda/mps/xla support to sample() #18

Closed. svwingerden closed this issue 10 months ago.

svwingerden commented 1 year ago

I think we can get basic CUDA support working before the hackathon. I'll write some quick speed benchmarks. Are there any other benchmarks you were thinking of?
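A minimal timing harness of the kind such quick speed benchmarks might use, assuming only the Python standard library. `benchmark()` is a hypothetical helper for illustration, not part of devinterp.

```python
import statistics
import timeit


def benchmark(fn, *, repeat=5, number=10):
    """Return the median seconds per call of fn() across `repeat` runs of `number` calls.

    Hypothetical helper. If fn launches CUDA work, it should call
    torch.cuda.synchronize() before returning: CUDA kernels run
    asynchronously, so without a sync the timing mostly measures launch overhead.
    """
    # timeit.repeat accepts a zero-argument callable as its stmt.
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    # The median is less sensitive than the mean to one-off warm-up runs.
    return statistics.median(totals) / number
```

Comparing devices would then mean calling `benchmark` on the same workload once per device and comparing the returned per-call times.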

svwingerden commented 1 year ago

Basic CUDA support and benchmarking are done. I'm waiting until after #15 to open a new PR.

svwingerden commented 1 year ago

Partially resolved by https://github.com/timaeus-research/devinterp/pull/33. MPS/XLA support is untested, AFAIK. Leaving this open until we've tested those.

svwingerden commented 1 year ago

XLA works fine. MPS is still untested.

svwingerden commented 10 months ago

MPS works, but some functions aren't implemented for MPS. Warnings are emitted where appropriate.
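A minimal sketch of the device selection and MPS warning described above. `resolve_device()`, its parameters, and the CUDA > MPS > CPU preference order are all hypothetical, chosen for illustration; this is not devinterp's actual API. The availability flags are passed in explicitly so the logic can run without a GPU.

```python
import warnings


def resolve_device(requested=None, *, cuda_available=False,
                   mps_available=False, xla_available=False):
    """Pick a device string for sampling (hypothetical helper)."""
    available = {
        "cpu": True,
        "cuda": cuda_available,
        "mps": mps_available,
        "xla": xla_available,
    }
    if requested is not None:
        # Accept indexed devices such as "cuda:1" or "xla:0".
        backend = requested.split(":")[0]
        if backend not in available:
            raise ValueError(f"unknown device {requested!r}")
        if not available[backend]:
            raise RuntimeError(f"device {requested!r} requested but not available")
        device = requested
    else:
        # Assumed default order: CUDA, then MPS, then CPU (always available).
        # XLA is only used when requested explicitly.
        device = next(d for d in ("cuda", "mps", "cpu") if available[d])
    if device.startswith("mps"):
        # Some operations are not implemented on MPS, so flag it to the user.
        warnings.warn("MPS backend: some operations may be unsupported", UserWarning)
    return device
```

In PyTorch, the flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. Separately, setting the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable makes PyTorch fall back to the CPU for ops that MPS does not implement.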