Great! This requirement needs to explicitly cast all parameters to `brainpy.math.float_`. For example, for an HH neuron model, its parameter `gNa` should be reinterpreted as `gNa = bm.asarray(gNa, bm.float_)`. Ideally, users can set `brainpy.math.set(float_=bm.float16)`, and then all variables will run with `float16` types.
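Concretely, a minimal sketch of the intended usage (the conductance value here is just an illustrative number):

```python
import brainpy.math as bm

# Set the global float type once; bm.float_ then resolves to float16.
bm.set(float_=bm.float16)

# Cast model parameters with the global default so they follow the setting.
gNa = 120.0                        # HH sodium conductance (illustrative value)
gNa = bm.asarray(gNa, bm.float_)   # now stored as float16
```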
One more thing that needs to be taken care of: the coefficients of the Runge-Kutta methods should also be cast to the `brainpy.math.float_` type.
Could you let me know how to cast the Runge-Kutta coefficients into `brainpy.math.float_`? It seems that the coefficients are automatically generated.
Yes, changes should be made in the BrainPy framework. Note that `dt` should also be cast in the integrators.
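For illustration, the idea is roughly the following (a sketch with placeholder names, not the actual BrainPy integrator internals, where the coefficients are generated automatically):

```python
import brainpy.math as bm

bm.set(float_=bm.float16)

# Hypothetical Runge-Kutta tableau entries; inside BrainPy these are
# auto-generated, so the cast would have to live in the framework itself.
rk_coeffs = [0.0, 0.5, 0.5, 1.0]
rk_coeffs = [bm.asarray(c, bm.float_) for c in rk_coeffs]

# The step size should be cast as well; otherwise a float32/float64 dt mixed
# with float16 state variables can trigger type promotion.
dt = bm.asarray(0.1, bm.float_)
```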
Update: I think GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16. The preallocation can be disabled with `brainpy.math.disable_gpu_memory_preallocation()`.
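A quick way to check is to turn preallocation off before building the model, roughly like this (a minimal sketch):

```python
import brainpy.math as bm

# Disable JAX's default preallocation (~75% of GPU memory) so the real
# footprint of an FP16 model becomes visible in nvidia-smi.
bm.disable_gpu_memory_preallocation()
bm.set(float_=bm.float16)
# ... build and run the network afterwards ...
```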
Hi, when running `bm.set(float_=bm.bfloat16)`, I get a `NotImplementedError`. Is `bfloat16` currently not supported in BrainPy?
It is supported, but the `set` operation does not recognize it. Maybe we should customize the `set()` function.
I guess we should just add one more condition in the `set_float` function in `brainpy/_src/math/environment.py`?
```python
def set_float(dtype: type):
  """Set global default float type.

  Parameters
  ----------
  dtype: type
    The float type.
  """
  if dtype in [jnp.float16, 'float16', 'f16']:
    defaults.__dict__['float_'] = jnp.float16
    defaults.__dict__['ti_float'] = ti.float16
  elif dtype in [jnp.float32, 'float32', 'f32']:
    defaults.__dict__['float_'] = jnp.float32
    defaults.__dict__['ti_float'] = ti.float32
  elif dtype in [jnp.float64, 'float64', 'f64']:
    defaults.__dict__['float_'] = jnp.float64
    defaults.__dict__['ti_float'] = ti.float64
  else:
    raise NotImplementedError
```
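Something like the following extra branch, spliced into the if/elif chain above, might be enough (a sketch only, not a tested patch; Taichi has no bfloat16 type, so mapping `ti_float` to `ti.float16` is just an assumption):

```python
  elif dtype in [jnp.bfloat16, 'bfloat16', 'bf16']:
    # bfloat16 for the JAX side; Taichi has no bfloat16 type, so fall back
    # to float16 there (assumption, may need a different choice).
    defaults.__dict__['float_'] = jnp.bfloat16
    defaults.__dict__['ti_float'] = ti.float16
```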
Yes!
Does BrainPy fully support half-precision floating point numbers? I have tried to change some of my own BrainPy code from using `brainpy.math.float32` to `brainpy.math.float16` or `brainpy.math.bfloat16` (by explicitly setting the dtype of all variables and using a debugger to make sure that they won't be promoted to `float32`), but it seems that the GPU memory consumption and running speed are almost the same as using `float32`.