google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Wider support for FP8 datatypes #16673

Closed kmheckel closed 10 months ago

kmheckel commented 1 year ago

@jakevdp I am trying to use float8 datatypes as part of training quantized spiking neural networks, but I cannot initialise weights for a network because jax.random._uniform does not support the jnp.float8_* dtypes. My intent is to train on an H100 GPU, which supports fp8 at double the throughput of fp16. The error message from jax.random._uniform also seems a little out of date, as it only mentions fp32/64 even though fp16 is accepted.

I'm not sure how much work it would take to permit fp8 types, but I would greatly appreciate it! It seems the uniform sampling method would need to be updated, but I'm not sure how many other changes would be required to support fp8. Beyond language model applications, neuroevolution would benefit greatly from fp8, since the population size could be doubled for the same compute/memory budget. SNN research would also benefit, as both the input spikes and the model weights could be represented in fp8, eliminating the need to recast int8 inputs to fp16.
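
For concreteness, a minimal sketch of the kind of call that fails for me (float8_e4m3fn as an example; the exact error text may vary by version):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Works for fp16/fp32/fp64, but jax.random currently rejects the float8 dtypes:
w = jax.random.uniform(key, (128, 128), dtype=jnp.float8_e4m3fn)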

Originally posted by @kmheckel in https://github.com/google/jax/discussions/16342#discussioncomment-6396762

jakevdp commented 1 year ago

Hi - thanks for the request! I think this is a fundamentally difficult question, for the same reasons as those discussed in https://github.com/google/jax/discussions/13798#discussioncomment-4499272. The issue is that typical algorithms for generating uniform deviates assume that the set of floating-point numbers behaves similarly to the set of real numbers, but as you reduce the bit width this becomes less and less true. For example, if we were generating uniform deviates in float8, we would quickly run into resolution limits when trying to select values between 0.0 and 1.0. Here's some quick code to visualize which values are actually possible to represent:

import ml_dtypes
import numpy as np
import matplotlib.pyplot as plt

# Enumerate all 256 bit patterns, reinterpret them as float8_e5m2,
# and keep only the values that fall in [0, 1).
x = np.arange(256, dtype='uint8').view(ml_dtypes.float8_e5m2).astype(float)
x_in_range = x[(0 <= x) & (x < 1)]
plt.hist(x_in_range, bins=100);

[Figure: histogram of the float8_e5m2 values representable in [0, 1), which cluster heavily near zero]

Again, this shows all the values that are possible to represent in float8_e5m2, and it makes clear that generating uniformly-distributed values in this range wouldn't be entirely straightforward, since it's really a discrete distribution. What do you think? Do you know of any existing work in this area?

kmheckel commented 1 year ago

@jakevdp, that's a good point - I'll have to do some research, as I don't know a good method for uniformly sampling 8-bit floats. I've looked at the Brevitas and bitsandbytes packages but haven't found anything there. Based on the other issue you linked, perhaps uniformly sampling in a higher precision and casting down to FP8 might be a decent hack?

My other thought would be some kind of inverse frequency weighting or using a lookup table since the 8 bit space is relatively small?
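
To make the lookup-table idea concrete, something along these lines might work (just a rough sketch, not a vetted approach): enumerate the representable float8_e5m2 values in [0, 1) once, then sample indices into that table.

import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np

# Build the table of representable float8_e5m2 values in [0, 1) once, on the host.
bits = np.arange(256, dtype='uint8').view(ml_dtypes.float8_e5m2).astype(np.float32)
table = jnp.asarray(np.unique(bits[(0 <= bits) & (bits < 1)]))

def table_uniform(key, shape):
    # Sampling indices uniformly is NOT uniform over [0, 1), since the
    # representable points cluster near zero; some reweighting would be needed.
    idx = jax.random.randint(key, shape, 0, table.size)
    return table[idx].astype(jnp.float8_e5m2)

samples = table_uniform(jax.random.PRNGKey(0), (4, 4))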

In the meantime I think I might explore binary initialisation for the networks since this is a tough question.

jakevdp commented 1 year ago

My other thought would be some kind of inverse frequency weighting or using a lookup table since the 8 bit space is relatively small?

Inverse frequency weighting requires floating-point math, so you'd probably have to do it in higher precision, at which point you've given up on the goal of generating float8 directly.

I think generating at high precision and then casting is a useful solution, if that's suitable for the use-case. But that's easy enough to do already with something like random.uniform(key, size).astype('float8'). It's not clear to me that we should default to that approach if you pass float8 to uniform.
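
For concreteness, a minimal sketch of that cast-down workaround (using float8_e4m3fn as an example target dtype):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Draw the samples in float32, then cast down to the 8-bit float dtype.
x = jax.random.uniform(key, (4, 4), dtype=jnp.float32).astype(jnp.float8_e4m3fn)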

kmheckel commented 1 year ago

I think generating at high precision and then casting is a useful solution, if that's suitable for the use-case. But that's easy enough to do already with something like random.uniform(key, size).astype('float8'). It's not clear to me that we should default to that approach if you pass float8 to uniform.

Yeah, that makes sense from a design perspective. The downside of having to rely on astype/casting at the moment is that I don't think it works with jmp, the JAX mixed-precision library, which allows setting the dtype for neural net modules in an organised fashion. I agree that it would be nice to have a way to generate directly in fp8: while it may not matter much for a neural network that is initialised only once, it would definitely negate any gains for applications such as neuroevolution if sampling still happens in fp16 under the hood before casting.

kmheckel commented 1 year ago

I haven't done a deep dive on the dtype internals, but is there a reason why the jnp.float8 types work when casting with astype but not when specified as the dtype of an array from the start?

Thank you so much for giving this attention!

>>> import jax.numpy as jnp
>>> import jax
>>> jax.random.rademacher(jax.random.PRNGKey(0), [3, 3]).astype(jnp.float8_e4m3fn)
Array([[ 1, -1,  1],
       [ 1, -1,  1],
       [ 1,  1, -1]], dtype=float8_e4m3fn)
>>> jax.random.rademacher(jax.random.PRNGKey(0), [3, 3], dtype=jnp.float8_e4m3fn)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/legion/.local/lib/python3.10/site-packages/jax/_src/random.py", line 1780, in rademacher
    return _rademacher(key, shape, dtype)
  [... JAX-internal frames through traceback_util.py, pjit.py, core.py, pxla.py, and mlir.py ...]
  File "/home/legion/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 249, in _numpy_array_constant
    attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
jax._src.traceback_util.UnfilteredStackTrace: ValueError: cannot include dtype '4' in a buffer

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/legion/.local/lib/python3.10/site-packages/jax/_src/random.py", line 1780, in rademacher
    return _rademacher(key, shape, dtype)
ValueError: cannot include dtype '4' in a buffer

jakevdp commented 1 year ago

It's because jax.random doesn't know about float8 dtypes, but jax.numpy does. They're still relatively new and experimental, and not fully supported across the JAX package.

kmheckel commented 1 year ago

Gotcha - while direct initialization of float8s seems like a tough problem, would it at least be possible to support float8 types beyond just jax.numpy, so that neural networks could at least perform inference in float8? It seems like this would require enabling them in the MLIR lowering... I know these types are new, but they're supported on the H100, so they'll be around for years, and it would be nice not to have to rely on NVIDIA's Transformer Engine package to bolt on float8 support.

Thank you so much for your time!

jakevdp commented 1 year ago

Sure, it’s possible to support float8 types in other places, and they’re already being used in neural network training. It seems like my recommendation of generating random numbers in float32 and then casting to float8 is what you want, and that’s currently a supported operation. Does that not address your issue?

kmheckel commented 1 year ago

@jakevdp, initialising/maintaining the parameters in higher precision works for the time being, but the issue remains that when trying to use float8 for compute/outputs there is still a hang-up that prevents things from running.

I know this involves Haiku/JMP, but I believe the blockage is within JAX itself.

policy I'm trying to use:

import jmp
import jax.numpy as jnp

policy = jmp.Policy(compute_dtype=jnp.float8_e4m3fn,
                    param_dtype=jnp.float16,
                    output_dtype=jnp.float8_e4m3fn)
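
For reference, the policy gets attached to the network modules roughly like this (a sketch using hk.Linear as a stand-in for the actual spyx module; the real wiring may differ):

import haiku as hk
import jax.numpy as jnp
import jmp

policy = jmp.Policy(compute_dtype=jnp.float8_e4m3fn,
                    param_dtype=jnp.float16,
                    output_dtype=jnp.float8_e4m3fn)

# Haiku casts the tagged module's params/compute/outputs according to the policy.
hk.mixed_precision.set_policy(hk.Linear, policy)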

last part of error trace when attempting to initialise the neural net:

~/.local/lib/python3.10/site-packages/spyx/nn.py in __call__(self, x, V)
    175             beta = hk.get_parameter("b", self.hidden_shape, 
    176                                 init=hk.initializers.TruncatedNormal(0.25, 0.5))
--> 177             beta = jax.nn.hard_sigmoid(beta)
    178 
    179         spikes = self.act(V - self.threshold)

    [... skipping hidden 18 frame]

~/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py in _numpy_array_constant(x, canonicalize_types)
    247     x = x.view(np.uint16)
    248   x = np.ascontiguousarray(x)
--> 249   attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
    250   return (hlo.ConstantOp(attr).result,)
    251 

ValueError: cannot include dtype '4' in a buffer

I can try to conjure up a minimal example if that would be helpful.

jakevdp commented 1 year ago

I see, thanks. This is an example of what I mentioned – FP8 is still experimental and not fully supported throughout JAX (mainly because the relevant operations are not supported on all hardware). Here's a minimal example of what you ran into:

import jax
import jax.numpy as jnp

x = jnp.array(0, dtype='float8_e4m3fn')
print(x + 1)
# 1.0

print(jax.jit(lambda x: x + 1)(x))
# ...
# ValueError: cannot include dtype '4' in a buffer

kmheckel commented 1 year ago

So just to be clear, the intent is that float8 will remain in its current experimental state and be limited to jax.numpy for the time being, until there is broader hardware support for it? If that's the case, then there's nothing more to address at this time.

Thank you for taking the time to respond and follow up on this!

jakevdp commented 1 year ago

I'm actually not sure what the intent is, but it's probably safe to assume that experimental narrow-width dtypes will have limited support at least for the near future.

jakevdp commented 1 year ago

#16696 fixes the float8+jit issue you saw above. There are still operations that are not supported for float8, but this gets us a step closer.

MoFHeka commented 1 year ago

kernel_init = lambda key, shape, dtype: jax.nn.initializers.normal(
    stddev=self.config.initializer_range)(key, shape, jnp.float32).astype(self.dtype)

This can fix the kernel_init problem. But I found my H100 running FP8 and FP16 at the same speed!
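
For context, this wraps a standard initializer so that sampling happens in float32 and the cast to the model dtype happens afterwards; a self-contained version of the same idea (the stddev and target dtype below are placeholders for self.config.initializer_range and self.dtype):

import jax
import jax.numpy as jnp

def make_kernel_init(stddev, target_dtype):
    base = jax.nn.initializers.normal(stddev=stddev)
    # Sample in float32, then cast down to the (possibly float8) target dtype.
    return lambda key, shape, dtype: base(key, shape, jnp.float32).astype(target_dtype)

kernel_init = make_kernel_init(0.02, jnp.float8_e4m3fn)
w = kernel_init(jax.random.PRNGKey(0), (16, 16), jnp.float8_e4m3fn)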

jakevdp commented 10 months ago

I'm going to close this – I think it's expected that for the foreseeable future float8 will only be supported for certain operations.