RatInABox-Lab / RatInABox

A python package for modelling locomotion in complex environments and spatially/velocity selective cell activity.
MIT License
175 stars, 31 forks

Multiplatform support with jax for GPU #20

Closed: TomGeorge1234 closed this issue 1 year ago

TomGeorge1234 commented 1 year ago

Could RIAB support GPU computation via JAX (import jax.numpy as jnp)?

This should be backward compatible, i.e. users should be able to optionally specify a GPU-usage flag; otherwise NumPy is used as normal.
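One way to keep this backward compatible is to route all array calls through a module chosen once from a flag. The sketch below is a minimal illustration under that assumption; `get_backend` and `use_gpu` are hypothetical names, not part of the RatInABox API. Note that jax.numpy places arrays on JAX's default device, which is the GPU when a CUDA-enabled jaxlib is installed.

```python
import numpy as np


def get_backend(use_gpu=False):
    """Return the array module to use (hypothetical helper, not RatInABox API).

    use_gpu=False (the default) returns plain NumPy, so existing code is unchanged.
    use_gpu=True returns jax.numpy, which runs on JAX's default device (GPU if available).
    """
    if not use_gpu:
        return np
    try:
        import jax.numpy as jnp
    except ImportError as err:
        raise ImportError("use_gpu=True requires JAX to be installed") from err
    return jnp


# Code written against `xp` works with either backend for most numpy-style calls
xp = get_backend(use_gpu=False)
positions = xp.array([[0.2, 0.5], [0.8, 0.1]])
distances = xp.sqrt(((positions[:, None, :] - positions[None, :, :]) ** 2).sum(axis=-1))
```

Defaulting to NumPy means users who never set the flag never need JAX installed at all.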

A GPU would not substantially speed up RIAB except for FeedForwardCells and any use case where the synaptic weight matrices into FeedForwardCells are learnt, which typically requires large N_cells x N_cells matrix multiplications. It is therefore likely that CPU-only (NumPy) execution will continue to satisfy the large majority of use cases.
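To make the cost argument concrete, here is a sketch of the two dense operations a feedforward layer with learnt weights performs each step. It assumes a ReLU activation and a simple Hebbian outer-product rule purely for illustration; neither `feedforward_rates` nor `hebbian_update` is RatInABox's actual implementation. With N = 1000 cells each call is about N^2 = 10^6 multiply-adds, which is the regime where a GPU starts to pay off.

```python
import numpy as np


def feedforward_rates(W, input_rates, xp=np):
    """Output firing rates from weights W with shape (N_out, N_in).

    Cost is ~N_out * N_in multiply-adds per call. The ReLU here is an
    assumed activation chosen only for illustration.
    """
    return xp.maximum(W @ input_rates, 0.0)


def hebbian_update(W, input_rates, output_rates, eta=1e-3, xp=np):
    """Hypothetical Hebbian learning step: W += eta * outer(post, pre).

    The outer product is another dense N_out x N_in operation per step.
    """
    return W + eta * xp.outer(output_rates, input_rates)


W = np.array([[1.0, -1.0], [0.5, 0.5]])
inputs = np.array([1.0, 2.0])
outputs = feedforward_rates(W, inputs)       # [0.0, 1.5]
W_new = hebbian_update(W, inputs, outputs, eta=0.1)
```

Because both functions take the array module `xp` as a parameter, passing jax.numpy instead of NumPy moves exactly these matrix operations onto the GPU.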

For now I do not intend to use JAX to JIT-compile (jax.jit) or vectorise (jax.vmap) RIAB code, a change that would require significant and likely backwards-incompatible modifications to the code base.
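Part of the reason JIT compilation is invasive is that jax.jit expects pure functions, and JAX arrays do not support in-place item assignment. The sketch below contrasts the stateful update style common in object-oriented NumPy code with the pure style JIT requires. It is a simplified leaky-integration update with hypothetical function names, not RatInABox code; only NumPy is used here, and wrapping `step_pure` in jax.jit is left as the assumed next step.

```python
import numpy as np


def step_in_place(state, dt, drive):
    # Stateful NumPy style: mutates `state` and returns nothing.
    # Fine for NumPy, but side effects like this are not traced by jax.jit.
    state["rates"] += dt * (drive - state["rates"])
    state["t"] += dt


def step_pure(rates, t, dt, drive):
    # Pure style: inputs are left untouched and new values are returned,
    # which is the shape of function jax.jit can compile.
    new_rates = rates + dt * (drive - rates)
    return new_rates, t + dt


rates = np.array([0.0, 1.0])
drive = np.array([1.0, 1.0])
new_rates, t = step_pure(rates, 0.0, 0.1, drive)   # new_rates -> [0.1, 1.0]
```

Converting every class method that mutates attributes into this form is what would make a JIT-compiled RIAB a backwards-incompatible rewrite rather than a drop-in flag.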

TomGeorge1234 commented 1 year ago

Update: still working on this. Progress has been made; ETA is probably another month.

TomGeorge1234 commented 1 year ago

As mentioned in #60, this is a feature being considered for version 2.0.