jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Disentangle "jax_enable_x64" and default dtype #22688

Open francois-rozet opened 4 months ago

francois-rozet commented 4 months ago

Description

Provide a way to enable float64 without changing the default dtype/precision or, even better, add an option to choose the default dtype. For example:

jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_dtype', jnp.float32)

Motivation

Some algorithms, such as fast Fourier transforms and conjugate gradient methods, can be numerically unstable at 32-bit precision, so it is recommended to run them at 64-bit precision. If such algorithms are part of a larger procedure, enabling 64-bit precision through jax_enable_x64 makes the entire procedure run at 64-bit precision. This is very wasteful, especially if the procedure involves neural networks.

The current solution is to rewrite the procedure so that it explicitly requests the float32 dtype at every array creation. For large code bases this is very tedious, and it prevents the use of most libraries.
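A minimal sketch of that workaround, assuming a toy float32 layer (the hypothetical `dense_layer`) that feeds a float64 conjugate-gradient solve done with `jax.scipy.sparse.linalg.cg`. Once `jax_enable_x64` is on, constructors such as `jnp.ones` and `jnp.eye` produce float64, so every array on the network side needs an explicit `dtype=jnp.float32`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Enable 64-bit types globally; do this at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)


def dense_layer(w, b, x):
    """Hypothetical network layer that we want to keep in float32."""
    return jnp.tanh(x @ w + b)


def solve_spd(a, rhs):
    """Solve a @ sol = rhs with conjugate gradient, entirely in float64."""
    sol, _ = cg(a.astype(jnp.float64), rhs.astype(jnp.float64))
    return sol


# With x64 enabled, constructors default to 64-bit types.
default_dtype = jnp.ones(3).dtype

# So float32 has to be requested explicitly at every array creation.
w = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.float32)
b = jnp.zeros(4, dtype=jnp.float32)
x = jnp.ones(4, dtype=jnp.float32)

h = dense_layer(w, b, x)  # stays float32
a = 2.0 * jnp.eye(4)      # no dtype given, so this is float64
sol = solve_spd(a, h)     # float64 solution of (2 I) sol = h

print(default_dtype, h.dtype, sol.dtype)  # float64 float32 float64
```

Forgetting a single `dtype=jnp.float32` silently moves that part of the computation to 64-bit, which is why this approach scales poorly to large code bases and third-party libraries.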

Alternatives

An even better solution would be to set the dtype locally (e.g., with context managers), but I suspect that this would not play well with tracers.

jakevdp commented 4 months ago

We attempted this in the past (see the jax_default_dtype_bits configuration option). It still exists and still kind of works, but it's entirely undocumented and untested, so I wouldn't rely on it. We abandoned the approach because it didn't seem worth the maintenance cost (doubling all our CI costs to check that default dtypes are respected).

jax_enable_x64 is a really problematic API, but we haven't been able to come up with a good solution that doesn't significantly impact current users who rely on the guarantees of the default setting.

@yashk2810 has been looking at a way to enable X64 locally that side-steps the problems with past approaches that have tried to do this. I think in the long run this will be a better fix than trying to support yet another global flag.

francois-rozet commented 4 months ago

Hello @jakevdp, thank you for your answer. Actually, why is float64 not available by default? I am asking because an easy solution would be to make 64-bit types available by default (without making them the default dtype) and remove that pesky UserWarning altogether. This would not impact current users, as it would simply allow something that was not possible before.

The only issue I see is the current casting of 64-bit NumPy arrays to 32-bit JAX arrays, which would become slightly weird.

The same could be done with yet another global flag (e.g., jax_enable_x64_lazy), which would not increase CI costs because it would not affect the default dtype.
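For reference, here is a small sketch of the NumPy-to-JAX casting mentioned above under the default configuration (`jax_enable_x64=False`). The flag is set explicitly only to make that assumption visible; the array values are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Default configuration: 64-bit types are disabled.
jax.config.update("jax_enable_x64", False)

x_np = np.linspace(0.0, 1.0, 5)  # NumPy creates float64 by default
x_jax = jnp.asarray(x_np)        # downcast to float32 on the way into JAX

print(x_np.dtype, x_jax.dtype)   # float64 float32

# The downcast loses precision: 0.1 is not exactly representable in float32.
tenth = np.float64(0.1)
print(float(jnp.asarray(tenth)) == tenth)  # False
```

If 64-bit types were merely allowed (rather than made the default), it is this implicit downcast that would have to change or become inconsistent.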

jakevdp commented 4 months ago

Actually, why is float64 not available by default?

Are you asking why jax_enable_x64 is false by default? If so, the answer is that one of JAX's original goals was fast computation on accelerators (GPU, TPU), which don't support 64-bit values. It was a decision made early in the project's development, and defaults are hard to change – especially global defaults that affect virtually every API – because changing them breaks every user.

francois-rozet commented 4 months ago

No, my question is why using float64 is prohibited by default.

jakevdp commented 4 months ago

That's the question I answered: jax_enable_x64 is the flag that controls whether 64-bit types can be used at all.

francois-rozet commented 4 months ago

My question was ambiguous; let me rephrase. Currently, jax_enable_x64 is the flag that both allows 64-bit types to be used and makes them the default.

Assuming that the default remains 32-bit types, is there a reason to disable 64-bit types? In PyTorch, the default is also 32-bit types, but 64-bit types are not prohibited.

especially global defaults that affect virtually every API – because changing them breaks every user.

I don't think that enabling 64-bit types without making them the default (which is not what jax_enable_x64 currently does) would impact users.

jakevdp commented 4 months ago

jax_enable_x64 technically does nothing to change the default output of functions.

jax_enable_x64=False means that any time a 64-bit value would be produced, it is downcast to a 32-bit value.

Some APIs do produce 64-bit outputs by default; for example, jnp.float64(1) produces a float64 value, and jax_enable_x64=False truncates it to a 32-bit value instead. But this has nothing to do with default dtypes per se.
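A minimal sketch of the truncation described above, assuming the default `jax_enable_x64=False`. JAX may emit a UserWarning when an explicitly requested float64 is unavailable; it is suppressed here so that only the resulting dtypes are shown.

```python
import warnings

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)  # the default setting

with warnings.catch_warnings():
    # JAX may warn that float64 is unavailable and will be truncated.
    warnings.simplefilter("ignore")
    y = jnp.float64(1)                      # asks for a float64 scalar
    z = jnp.zeros(2, dtype=jnp.float64)     # asks for a float64 array

print(y.dtype, z.dtype)  # float32 float32
print(float(y))          # 1.0
```

The requested dtype is honored only up to what the flag allows: the value survives, but its precision is capped at 32 bits.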

Does that make sense?

Assuming that the default remain 32-bit types, is there a reason to disable 64-bit types?

Yes: there is a very good reason to disable 64-bit types entirely. Any 64-bit value that enters a computation on an accelerator leads either to an error (on hardware that does not support 64-bit types) or to a slow computation (in general, on hardware that does support them). NumPy-style APIs often generate 64-bit types by default. For example, if you add an int32 and a uint32, you get an int64, and this happens silently. Similarly, Python scalars are 64-bit, so when you pass 1.0 to a function that converts it to an array, the result is a float64. This is important to realize: 64-bit values tend to sneak into programs, no matter how careful you are about your defaults. It's the nature of NumPy and Python.
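Both NumPy behaviors mentioned above are easy to reproduce. This is a small sketch assuming NumPy's standard array promotion rules and the default `jax_enable_x64=False`. It shows how 64-bit values appear without being requested, and how JAX currently truncates them on entry.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", False)  # the default setting

# int32 + uint32: no 32-bit integer type holds both ranges, so NumPy picks int64.
c = np.array([1, 2], dtype=np.int32) + np.array([3, 4], dtype=np.uint32)
print(c.dtype)  # int64

# Python floats are double precision, so NumPy turns 1.0 into a float64 array.
s = np.asarray(1.0)
print(s.dtype)  # float64

# With the flag off, JAX truncates both to 32-bit types when they enter.
print(jnp.asarray(c).dtype, jnp.asarray(s).dtype)  # int32 float32
```

With `jax_enable_x64=True`, the same two conversions would keep int64 and float64, which is the "sneaking in" effect described above.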

In the early days of JAX, this led to a lot of problems when trying to run neural network models on accelerators, and so the jax_enable_x64 flag was added to make sure that 64-bit values never enter the computation. Was it a good design decision? Probably not, but it solved the problem of the moment, and allowed the demo to run correctly that week, and paved the way for what JAX has become today.

Can we change this decision now? Yes and no. The problem is that as soon as you relax the behavior of the X64 flag so that it allows 64-bit values, then those values tend to sneak in at unintended places. This breaks code for important stakeholders. It's really hard to make changes to these kinds of defaults without breaking people in subtle or not-so-subtle ways: I spent nearly a year attempting to do that a while back, and I ended up having to give up on the project due to its infeasibility.

Does that help clarify things?

francois-rozet commented 4 months ago

Thank you for your answer! So the problem is casting values from Python and NumPy to JAX, as I had guessed. In that case, a solution based on context managers is probably better. I would be happy to help @yashk2810.

yashk2810 commented 4 months ago

I have a prototype here: https://github.com/google/jax/pull/21472 but it still has some problems I need to figure out.