carlosgmartin opened this issue 2 years ago
Thanks for the question! I'm assigning @mattjj because he knows some of the context of discussions around these kinds of ideas.
More concisely: Let `zero` be a value of float0 type and `any` be a value of any type. Then the following (and perhaps more) ought to hold:

```
zero + any == any
any + zero == any
zero * any == zero
any * zero == zero
-zero == zero
zero - any == zero + -any == -any
any - zero == any + -zero == any
abs(zero) == zero
```
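The identities above can be captured in a few lines of plain Python. The following is only a sketch of the proposed algebra (it is not JAX's actual `float0` implementation, which is a zero-sized NumPy dtype rather than a class):

```python
class Float0:
    """Sketch of a float0 value satisfying the proposed identities."""

    def __add__(self, other):   # zero + any == any
        return other

    __radd__ = __add__          # any + zero == any

    def __mul__(self, other):   # zero * any == zero
        return self

    __rmul__ = __mul__          # any * zero == zero

    def __neg__(self):          # -zero == zero
        return self

    def __sub__(self, other):   # zero - any == -any
        return -other

    def __rsub__(self, other):  # any - zero == any
        return other

    def __abs__(self):          # abs(zero) == zero
        return self

zero = Float0()
print(zero + 5, 5 + zero, zero - 5, 5 - zero)  # 5 5 -5 5
```

Every operation either returns the other operand unchanged (additive identity) or returns `zero` itself (multiplicative absorber), which is exactly the behavior being requested.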
Description
I have some modules whose parameters include discrete values (e.g. indices). After turning on `allow_int=True` in `jax.grad` to handle these, I get the following error:

The problem comes from multiplying the gradients of the discrete parameters (which happen to belong to the trivial vector space) by the learning rate.
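A minimal reproduction of this kind of setup might look like the following. The model, parameter names, and loss are illustrative assumptions, not the original code from the report:

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# Hypothetical model: one float weight and one discrete index parameter.
def loss(w, idx):
    return jnp.sum(w ** 2)  # idx does not enter the differentiable part

w = jnp.array([1.0, 2.0])
idx = jnp.array([3, 7])  # discrete (integer) parameter

# allow_int=True lets jax.grad accept the integer argument; its gradient
# comes back as an array of dtype float0 (the trivial vector space).
g_w, g_idx = jax.grad(loss, argnums=(0, 1), allow_int=True)(w, idx)
print(g_idx.dtype)

try:
    update = 0.1 * g_idx  # scaling by the learning rate is what fails
except TypeError as e:
    print("TypeError:", e)
```

Scaling `g_idx` by the learning rate raises the error, because no arithmetic is defined for float0 values.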
It is very inconvenient for float0 not to work like the other float types in this situation.

It is also mathematically incorrect for it not to. The trivial vector space contains only a zero vector. It is perfectly legitimate to scale this zero vector by any scalar quantity: doing so just yields the zero vector again. The same goes for addition and subtraction: adding or subtracting two zero vectors just yields the zero vector again.
Furthermore, it would be convenient if one could add a float0 to a discrete value (simply yielding the latter), so that optimizers' parameter updates of the form

```
param_new = param_old + lr * grad
```

can work without modification. Mathematically, this is justified by the fact that the trivial vector space is a subspace of every vector space, and the trivial module (which it is equivalent to) is a submodule of every module.

In short, a float0 should act as an absorber/annihilator under multiplication and as an identity under addition.
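Until such semantics exist, one possible workaround is to special-case float0 gradients in the update rule. The sketch below mirrors JAX's definition of float0 (a zero-sized structured NumPy dtype) so it runs without JAX installed; `apply_update` is a hypothetical helper, not part of any optimizer library:

```python
import numpy as np

# JAX defines float0 as a structured dtype with a zero-sized void field;
# mirrored here so the sketch is self-contained.
float0 = np.dtype([('float0', np.void, 0)])

def apply_update(param, grad, lr):
    """Parameter update that treats a float0 gradient as the zero vector."""
    if grad.dtype == float0:
        return param  # discrete parameter: gradient is identically zero
    return param - lr * grad

idx = np.array([3, 7])                       # discrete parameter
g_idx = np.zeros(idx.shape, dtype=float0)    # its float0 "gradient"
w = np.array([1.0, 2.0])
g_w = np.array([0.5, 0.5])

print(apply_update(idx, g_idx, 0.1))  # [3 7]  (unchanged)
print(apply_update(w, g_w, 0.1))      # [0.95 1.95]
```

This reproduces the requested algebra (`param + lr * zero == param`) by short-circuiting rather than by actually defining arithmetic on float0.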
Example:
Output:
What jax/jaxlib version are you using?
jax v0.3.17, jaxlib v0.3.15
Which accelerator(s) are you using?
CPU
Additional System Info
Python 3.10.5, macOS 11.6.8