google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

float0 should support addition, subtraction, and scalar multiplication #12339

Open · carlosgmartin opened this issue 2 years ago

carlosgmartin commented 2 years ago

Description

I have some modules whose parameters include discrete values (e.g. indices). After passing allow_int=True to jax.grad to handle these, I get the following error:

TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.

The problem comes from multiplying the gradients of the discrete parameters (which belong to the trivial vector space) by the learning rate, as in the sketch below.
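
A minimal sketch of the failing pattern (the loss function, parameter names, and learning rate below are illustrative placeholders, not the actual modules):

import jax
import jax.numpy as jnp

def loss(params):
    # a float weight combined with integer index parameters
    return (params["w"] * jnp.array([.8, .9])[params["idx"]]).sum()

params = {"w": jnp.ones(2), "idx": jnp.array([0, 1])}
grads = jax.grad(loss, allow_int=True)(params)

# grads["idx"] has dtype float0, so the usual SGD update fails with a
# TypeError as soon as that leaf is multiplied by the learning rate:
#   jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)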

It is very inconvenient that float0 does not work like the other float types in this situation.

It is also mathematically incorrect for it not to. The trivial vector space contains only a zero vector. It is perfectly legitimate to scale this zero vector by any scalar: doing so just yields the zero vector again. The same goes for addition and subtraction: adding or subtracting two zero vectors yields the zero vector again.
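
Spelled out, for the unique element 0 of the trivial vector space and any scalar c:

c * 0 = 0
0 + 0 = 0
0 - 0 = 0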

Furthermore, it would be convenient if one could add a float0 to a discrete value (simply yielding the latter), so that optimizers' parameter updates of the form param_new = param_old + lr * grad can work without modification. Mathematically, this is justified by the fact that the trivial vector space is a subspace of every vector space, and the trivial module (which it is equivalent to) is a submodule of every module.

In short, a float0 should act like an absorber/annihilator under multiplication and an identity under addition.
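
Until float0 itself supports these operations, a user-side update rule with the proposed semantics can special-case float0 leaves. A minimal sketch (sgd_update is a hypothetical helper, not part of JAX):

from jax.dtypes import float0

def sgd_update(param, grad, lr=0.1):
    # A float0 gradient lives in the trivial vector space, so under the
    # proposed semantics the update reduces to leaving the parameter as-is.
    if grad.dtype == float0:
        return param
    return param - lr * grad

# e.g. applied leaf-wise to the mixed pytree from the sketch above:
#   new_params = jax.tree_util.tree_map(sgd_update, params, grads)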

Example reproducing the errors:

import jax

def f(i):
    return jax.numpy.array([.8, .9])[i].sum()

# take the gradient of f with respect to its integer argument
fp = jax.grad(f, allow_int=True)
i = jax.numpy.array([0, 0, 1])
fp_i = fp(i)
print(fp_i.dtype)
print(fp_i.shape)

print()
try:
    # scalar multiplication with a float0 array
    print(fp_i * .3)
except Exception as e:
    print(e)

print()
try:
    # adding two float0 arrays
    print(fp_i + fp_i)
except Exception as e:
    print(e)

print()
try:
    # adding a float0 array to the original integer argument
    print(i + fp_i)
except Exception as e:
    print(e)

Output:

[('float0', 'V')]
(3,)

ufunc 'multiply' did not contain a loop with signature matching types (dtype([('float0', 'V')]), dtype('float64')) -> None

ufunc 'add' did not contain a loop with signature matching types (dtype([('float0', 'V')]), dtype([('float0', 'V')])) -> None

Called add with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
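
For completeness, a workaround along the lines suggested by the error message (the np.float alias it mentions has since been removed from NumPy, so an explicit dtype is used here):

import numpy as np

# Cast the float0 result to an ordinary all-zeros float array, after which
# the arithmetic above works as expected.
fp_i_zeros = np.zeros(fp_i.shape, dtype=np.float32)
print(fp_i_zeros * .3)          # [0. 0. 0.]
print(fp_i_zeros + fp_i_zeros)
print(i + fp_i_zeros)           # [0. 0. 1.]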

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

CPU

Additional System Info

Python 3.10.5, macOS 11.6.8

jakevdp commented 2 years ago

Thanks for the question! I'm assigning @mattjj because he knows some of the context of discussions around these kinds of ideas.

carlosgmartin commented 1 year ago

More concisely: Let zero be a value of float0 type and any be a value of any type. Then the following (and perhaps more) ought to hold: