brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Half precision (float16 or bfloat16) support #539

Open CloudyDory opened 10 months ago

CloudyDory commented 10 months ago

Does BrainPy fully support half-precision floating point numbers? I have tried changing some of my own BrainPy code from brainpy.math.float32 to brainpy.math.float16 or brainpy.math.bfloat16 (by explicitly setting the dtype of all variables and using a debugger to make sure they are not promoted to float32), but the GPU memory consumption and running speed are almost the same as with float32.
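
(For context, the explicit dtype setting looks roughly like the following; the variable names and sizes are just illustrative, not my actual model.)

import brainpy.math as bm

# Illustrative sketch: create the state variables directly in half precision
# so that nothing silently defaults to float32.
V = bm.Variable(bm.zeros(1000, dtype=bm.float16))   # e.g. membrane potentials
g = bm.Variable(bm.zeros(1000, dtype=bm.float16))   # e.g. synaptic conductances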

chaoming0625 commented 10 months ago

Great! This requires explicitly casting all parameters to brainpy.math.float_. For example, for an HH neuron model, the parameter gNa should be recast as gNa = bm.asarray(gNa, bm.float_). Ideally, users can call brainpy.math.set(float_=bm.float16), and then all variables will run with the float16 type.
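
A minimal sketch of that workflow (the conductance value is just an example number):

import brainpy.math as bm

# Set the global default float type first, so brainpy.math.float_ is float16.
bm.set(float_=bm.float16)

# Then recast each model parameter to the current default float type.
gNa = 120.0                        # example maximal Na+ conductance
gNa = bm.asarray(gNa, bm.float_)   # now stored as float16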

chaoming0625 commented 10 months ago

One more thing to take care of: the coefficients of the Runge-Kutta methods should also be cast to the brainpy.math.float_ type.

CloudyDory commented 10 months ago

> One more thing to take care of: the coefficients of the Runge-Kutta methods should also be cast to the brainpy.math.float_ type.

Could you let me know how to cast the Runge-Kutta coefficients to brainpy.math.float_? It seems that the coefficients are generated automatically.

chaoming0625 commented 10 months ago

Yes, those changes need to be made inside the BrainPy framework. Note that dt should also be cast in the integrators.
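
To make the idea concrete, here is a generic sketch (not BrainPy's actual integrator code) of an RK4 step that casts both the tableau coefficients and dt to the working dtype:

import jax.numpy as jnp

def rk4_step(f, x, t, dt, dtype=jnp.float16):
  # Cast dt and the Runge-Kutta coefficients to the working dtype; coefficients
  # kept as float32 arrays would otherwise promote the state back to float32
  # under JAX's type-promotion rules.
  dt = jnp.asarray(dt, dtype)
  half = jnp.asarray(0.5, dtype)
  two = jnp.asarray(2.0, dtype)
  sixth = jnp.asarray(1.0 / 6.0, dtype)

  k1 = f(t, x)
  k2 = f(t + half * dt, x + half * dt * k1)
  k3 = f(t + half * dt, x + half * dt * k2)
  k4 = f(t + dt, x + dt * k3)
  return x + dt * sixth * (k1 + two * (k2 + k3) + k4)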

CloudyDory commented 9 months ago

Update: I think the GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16.

chaoming0625 commented 9 months ago

The preallocation can be disabled by calling brainpy.math.disable_gpu_memory_preallocation().
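
For example (the environment-variable alternative is standard JAX behaviour, not BrainPy-specific):

import brainpy.math as bm

# Turn off JAX's default preallocation (75% of GPU memory) so that the
# memory actually allocated by float16 buffers becomes visible.
bm.disable_gpu_memory_preallocation()

# Equivalent alternative: set XLA_PYTHON_CLIENT_PREALLOCATE=false in the
# environment before JAX is imported.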

CloudyDory commented 6 months ago

Hi, when running bm.set(float_=bm.bfloat16), I get a NotImplementedError. Is bfloat16 currently not supported in BrainPy?

chaoming0625 commented 6 months ago

It is supported, but the set operation does not recognize it. Maybe we should customize the set() function.

CloudyDory commented 6 months ago

I guess we should just add one more condition in the set_float function in brainpy/_src/math/environment.py?

# Current implementation in brainpy/_src/math/environment.py
# (jnp = jax.numpy, ti = taichi; `defaults` holds the global default dtypes).
def set_float(dtype: type):
  """Set global default float type.

  Parameters
  ----------
  dtype: type
    The float type.
  """
  if dtype in [jnp.float16, 'float16', 'f16']:
    defaults.__dict__['float_'] = jnp.float16
    defaults.__dict__['ti_float'] = ti.float16
  elif dtype in [jnp.float32, 'float32', 'f32']:
    defaults.__dict__['float_'] = jnp.float32
    defaults.__dict__['ti_float'] = ti.float32
  elif dtype in [jnp.float64, 'float64', 'f64']:
    defaults.__dict__['float_'] = jnp.float64
    defaults.__dict__['ti_float'] = ti.float64
  else:
    raise NotImplementedError
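
A minimal sketch of the extra branch (the ti_float choice is an assumption: Taichi has no bfloat16 type, so the Taichi-side default falls back to float16 here):

  elif dtype in [jnp.bfloat16, 'bfloat16', 'bf16']:
    defaults.__dict__['float_'] = jnp.bfloat16
    # Assumption: Taichi has no bfloat16 dtype, so keep the Taichi-side
    # default at float16 (or raise if a true bfloat16 Taichi type is required).
    defaults.__dict__['ti_float'] = ti.float16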

chaoming0625 commented 6 months ago

Yes!