JAX also has custom_root and custom_linear_solve, but I'm not sure of the particular use cases for those right now. I think custom_grad and stop_grad are enough for now.
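For concreteness, here is a rough sketch of what the user-facing API could look like. Nothing below exists in Nx yet: the names custom_grad/2 and stop_grad/1, and the shape of the gradient callback (receiving the forward result and the upstream gradient, returning {input, gradient} pairs), are assumptions modelled on JAX's custom_vjp and stop_gradient, and the snippets assume an import Nx.Defn context.
# Hypothetical API sketch only; neither custom_grad nor stop_grad exists yet.
defn stable_log1pexp(t) do
  # Attach a hand-written, numerically stable gradient to the expression.
  custom_grad(Nx.log1p(Nx.exp(t)), fn _ans, g ->
    [{t, Nx.multiply(g, Nx.subtract(1.0, Nx.divide(1.0, Nx.add(1.0, Nx.exp(t)))))}]
  end)
end

defn frozen(t) do
  # Treat t as a constant: no gradient flows through this expression.
  stop_grad(t)
end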
I wanted to take a crack at this but I think it's a little out of my scope 😅. I'm not as comfortable with macros as I'd like to be to implement this. For someone who does end up working on this, I wrote these test cases so they can check their implementation (the following provides two examples of custom gradients that behave better than autograd):
describe "custom gradients" do
# Numerical stability example
defn log1pexp(t) do
res = Nx.log1p(Nx.exp(t))
transform(res, &IO.inspect/1)
res
end
defn grad_log1pexp(t) do
# Autograd computes a derivative that cannot be calculated.
# This computation will usually result in a 0 times an infinity
# when querying a large number.
res = grad(t, log1pexp(t))
transform(res, &IO.inspect/1)
res
end
defn manual_grad_log1pexp(t) do
# We can define a custom gradient with better behaviour;
# that is, we can define one that doesn't blow up asymptotically
# or doesn't involve the computation of a 0 times an infinity,
# or dividing by 0 or infinity.
res = Nx.subtract(1.0, Nx.divide(1.0, Nx.add(1.0, Nx.exp(t))))
transform(res, &IO.inspect/1)
res
end
  # Custom differentiation domain example
  defn r1psqrt(t) do
    res = Nx.divide(t, Nx.add(1.0, Nx.sqrt(t)))
    transform(res, &IO.inspect/1)
    res
  end

  defn grad_r1psqrt(t) do
    # Sometimes we want a derivative different from the one autograd
    # produces, because we want the derivative's domain to be different
    # from the one assumed by the autograd machinery.
    # The function above is only defined on the non-negative numbers, [0, +∞).
    # On that domain it is differentiable at 0 as well (see Rudin's
    # Principles of Mathematical Analysis, Definition 5.1, or Tao's
    # Analysis I, 3rd ed., Definition 10.1.1 and Example 10.1.6).
    # However, autograd will generate a derivative expression that cannot
    # be evaluated at 0, because it is unable to simplify and rewrite the
    # derivative it generates.
    res = grad(t, r1psqrt(t))
    transform(res, &IO.inspect/1)
    res
  end
  defn manual_grad_r1psqrt(t) do
    # We, as humans, can take the derivative that autograd generates,
    # simplify it further into a form that can be evaluated at 0, and
    # enforce our own desired differentiation convention.
    res =
      Nx.divide(
        Nx.add(Nx.sqrt(t), 2.0),
        Nx.multiply(2.0, Nx.power(Nx.add(Nx.sqrt(t), 1.0), 2))
      )

    transform(res, &IO.inspect/1)
    res
  end
test "checks custom gradient against autograd gradient for log1pexp" do
# Check that the gradients are equivalent at large numbers
check_grads!(&log1pexp/1, &grad_log1pexp/1, 709)
check_grads!(&log1pexp/1, &manual_grad_log1pexp/1, 709)
# Assert equality of autograd and custom gradient for large numbers
assert grad_log1pexp(709) == manual_grad_log1pexp(709)
end
test "checks custom gradient against autograd gradient for r1psqrt" do
# check_grads!/5 will not work on the following functions because
# running finite differences will lead to taking the square root of
# a negative number.
# check_grads!(&r1psqrt/1, &grad_r1psqrt/1, 0)
# check_grads!(&r1psqrt/1, &manual_grad_r1psqrt/1, 0)
# Running the autograd for this function at 0 will not work because
# autograd will not simplify the generated derivative of this function
# to one where evaluating it at 0 will not generate a 0 dividing by 0 error.
# assert grad_r1psqrt(0) == manual_grad_r1psqrt(0)
# Our custom gradient function will differentiate better than autograd
# because we can rewrite the derivative to operate on a different
# domain.
assert manual_grad_r1psqrt(0) == Nx.tensor(1.0)
end
end
These test cases are taken straight out of the JAX documentation for custom gradients: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#example-problems.
Thanks @shak360! I will break down how someone should implement this feature.
1. Add an Nx.Defn.Expr.metadata(expr, metadata) function. This will create a metadata-specific node. Now you can do this:
defmacro stop_grad(expr) do
  quote do
    Nx.Defn.Kernel.transform(unquote(expr), fn expr ->
      # Tag the expression(s) with metadata the grad pass can recognize.
      Nx.Defn.Expr.traverse_exprs(expr, &Nx.Defn.Expr.metadata(&1, %{stop_grad: true}))
    end)
  end
end
2. You will have to handle the new metadata node in the passes that traverse expressions; in particular, the grad pass should stop propagating gradients through expressions tagged with stop_grad: true.
3. Using the metadata above, add custom_grad too. A rough sketch of how it could reuse the same mechanism follows.
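Purely as a sketch of the idea (the transform/traverse_exprs/metadata calls mirror the stop_grad example above, but the exact custom_grad signature and the shape of the gradient callback are assumptions, not a settled API): custom_grad could stash the user-supplied gradient function in the metadata, and the grad pass would call that function instead of differentiating the wrapped expression.
defmacro custom_grad(expr, grad_fun) do
  quote do
    Nx.Defn.Kernel.transform(unquote(expr), fn expr ->
      # Store the caller's gradient function in the metadata so the grad
      # pass can invoke it instead of recursing into the subexpression.
      Nx.Defn.Expr.traverse_exprs(
        expr,
        &Nx.Defn.Expr.metadata(&1, %{custom_grad: unquote(grad_fun)})
      )
    end)
  end
end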
I think the first PR can cover step 1 and step 2. Then work on step 3.
Working on this because @seanmor5 needs it. :)
At least custom_grad and zero_grad. Anything else @seanmor5?