elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Add grad helpers #148

Closed · josevalim closed this 3 years ago

josevalim commented 3 years ago

At least custom_grad and zero_grad. Anything else @seanmor5?

seanmor5 commented 3 years ago

JAX also has custom_root and custom_linear_solve, but I'm not sure of the particular use cases for those right now. I think custom_grad and stop_grad are enough for now.
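
For context, stop_grad would play the same role as jax.lax.stop_gradient: it blocks gradient flow through a subexpression. Here is a purely illustrative sketch of the intended semantics; the helper does not exist yet, so the name and call shape are only a proposal:

defn stop_grad_example(t) do
  # A stop_grad helper would make the second term behave like a constant
  # during differentiation, so only the first term contributes.
  Nx.add(Nx.multiply(t, t), stop_grad(Nx.multiply(t, t)))
end

grad(t, stop_grad_example(t)) would then evaluate to 2t instead of 4t.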

shak360 commented 3 years ago

I wanted to take a crack at this, but I think it's a little out of my scope 😅. I'm not as comfortable with macros as I'd like to be to implement this. For whoever does end up working on this, I wrote these test cases to check an implementation against (they provide two examples of custom gradients that behave better than the autograd-generated ones):

describe "custom gradients" do
    # Numerical stability example
    defn log1pexp(t) do
      res = Nx.log1p(Nx.exp(t))
      transform(res, &IO.inspect/1)
      res
    end

    defn grad_log1pexp(t) do
      # Autograd produces the derivative in a numerically unstable form:
      # evaluating it for a large input multiplies a value that underflows
      # to 0 by one that overflows to infinity, yielding NaN.
      res = grad(t, log1pexp(t))
      transform(res, &IO.inspect/1)
      res
    end

    defn manual_grad_log1pexp(t) do
      # We can define a custom gradient with better numerical behaviour;
      # that is, one that does not blow up asymptotically and never
      # computes 0 times infinity or divides by 0 or infinity.
      res = Nx.subtract(1.0, Nx.divide(1.0, Nx.add(1.0, Nx.exp(t))))
      transform(res, &IO.inspect/1)
      res
    end

    # Custom differentiation domain example
    defn r1psqrt(t) do
      res = Nx.divide(t, Nx.add(1.0, Nx.sqrt(t)))
      transform(res, &IO.inspect/1)
      res
    end

    defn grad_r1psqrt(t) do
      # Sometimes we want a derivative different from the one autograd
      # produces because we want its domain to differ from the one assumed
      # by the autograd machinery.
      # The function above is only defined on the non-negative numbers,
      # [0, +∞), and it is differentiable at 0 (see Rudin's Principles of
      # Mathematical Analysis, Definition 5.1, or Tao's Analysis I, 3rd ed.,
      # Definition 10.1.1 and Example 10.1.6).
      # However, autograd generates a derivative that cannot be evaluated
      # at 0, because it does not simplify and rewrite the expression it
      # produces.
      res = grad(t, r1psqrt(t))
      transform(res, &IO.inspect/1)
      res
    end

    defn manual_grad_r1psqrt(t) do
      # We can take the derivative autograd generates, simplify it into a
      # form that is well-defined at 0, and thereby enforce our own
      # differentiation convention.
      res =
        Nx.divide(
          Nx.add(Nx.sqrt(t), 2.0),
          Nx.multiply(2.0, Nx.power(Nx.add(Nx.sqrt(t), 1.0), 2))
        )
      transform(res, &IO.inspect/1)
      res
    end

    test "checks custom gradient against autograd gradient for log1pexp" do
      # Check that the gradients are equivalent at large numbers
      check_grads!(&log1pexp/1, &grad_log1pexp/1, 709)
      check_grads!(&log1pexp/1, &manual_grad_log1pexp/1, 709)

      # Assert equality of autograd and custom gradient for large numbers
      assert grad_log1pexp(709) == manual_grad_log1pexp(709)
    end

    test "checks custom gradient against autograd gradient for r1psqrt" do
      # check_grads! will not work on the functions below because running
      # finite differences around 0 requires taking the square root of a
      # negative number.
      # check_grads!(&r1psqrt/1, &grad_r1psqrt/1, 0)
      # check_grads!(&r1psqrt/1, &manual_grad_r1psqrt/1, 0)

      # Running autograd for this function at 0 will not work either:
      # autograd does not simplify the generated derivative, so evaluating
      # it at 0 produces a 0 divided by 0.
      # assert grad_r1psqrt(0) == manual_grad_r1psqrt(0)

      # Our custom gradient function will differentiate better than autograd
      # because we can rewrite the derivative to operate on a different
      # domain.
      assert manual_grad_r1psqrt(0) == Nx.tensor(1.0)
    end
  end

These test cases are taken straight out of the JAX documentation for custom gradients: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#example-problems.
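
For reference, here is the calculus those code comments allude to, written out once (just the chain/quotient rule plus an algebraic simplification):

\frac{d}{dt}\log(1 + e^t) = \frac{1}{1 + e^t}\cdot e^t = 1 - \frac{1}{1 + e^t}

\frac{d}{dt}\,\frac{t}{1 + \sqrt{t}} = \frac{(1 + \sqrt{t}) - \frac{t}{2\sqrt{t}}}{(1 + \sqrt{t})^2} = \frac{\sqrt{t} + 2}{2(1 + \sqrt{t})^2}

Autograd evaluates the unsimplified middle forms as written: at t = 709 the first becomes 0 times infinity (in f32, Nx's default float type, the exponential overflows while the reciprocal underflows), and at t = 0 the second contains the 0/0 term t/(2√t). The simplified right-hand forms are exactly what manual_grad_log1pexp and manual_grad_r1psqrt compute.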

josevalim commented 3 years ago

Thanks @shak360! I will break down how someone should implement this feature.

Step 1: Add a metadata expr

Add Nx.Defn.Expr.metadata(expr, metadata). This will create a metadata-specific node. Now you can do this:

defmacro stop_grad(expr) do
  quote do
    Nx.Defn.Kernel.transform(unquote(expr), fn expr ->
      Nx.Defn.Expr.traverse_exprs(expr, &Nx.Defn.Expr.metadata(&1, %{stop_grad: true}))
    end)
  end
end
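
A minimal sketch of the constructor itself, assuming metadata nodes look like the other Nx.Defn.Expr nodes (an op plus an args list); make_expr/2 below is only a stand-in for whatever internal node constructor the module really uses:

def metadata(expr, metadata) when is_map(metadata) do
  # Hypothetical: wrap the expression in a :metadata node so the wrapped
  # expression and its metadata map travel through the tree together.
  make_expr(:metadata, [expr, metadata])
end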

Step 2: Handle the new metadata node

You will have to handle it:

  1. In EXLA, which can simply ignore the metadata node and keep traversing the wrapped expression
  2. In Nx.Defn.Evaluator, which can likewise ignore the metadata node and keep traversing the wrapped expression
  3. In Nx.Defn.Grad, which will look at the stop_grad key in the metadata and stop differentiating that branch if it is set, or keep traversing otherwise (a rough sketch follows below)
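
A rough sketch of the Nx.Defn.Grad handling from item 3. The clause and helper names below (to_grad/3, zero_like/1) are placeholders rather than the module's real internal API, so treat this as pseudocode to adapt:

# Hypothetical clause shapes only.
defp to_grad(%{op: :metadata, args: [_expr, %{stop_grad: true}]} = ans, _g, cache) do
  # stop_grad: treat the wrapped expression as a constant, so its
  # contribution to the gradient is a zero of the answer's shape.
  {zero_like(ans), cache}
end

defp to_grad(%{op: :metadata, args: [expr, _other]}, g, cache) do
  # Any other metadata is transparent to differentiation:
  # unwrap the node and keep differentiating the inner expression.
  to_grad(expr, g, cache)
end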

Step 3: Effectively add custom_grad

Using the metadata above, add custom_grad too.
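
One possible shape for custom_grad on top of the same metadata mechanism; the callback's exact contract (what it receives and what it returns) would be settled in the PR, so this is only a sketch:

defmacro custom_grad(expr, grad_fun) do
  quote do
    Nx.Defn.Kernel.transform(unquote(expr), fn expr ->
      Nx.Defn.Expr.metadata(expr, %{custom_grad: unquote(grad_fun)})
    end)
  end
end

Nx.Defn.Grad would then, when it finds the custom_grad key, call the stored function with the upstream gradient instead of differentiating the wrapped expression. From the user's side it could look like this, reusing the stable derivative from the test cases above:

defn stable_log1pexp(t) do
  custom_grad(Nx.log1p(Nx.exp(t)), fn _ans, g ->
    # hypothetical contract: return {input, gradient} pairs
    [{t, Nx.multiply(g, Nx.subtract(1.0, Nx.divide(1.0, Nx.add(1.0, Nx.exp(t)))))}]
  end)
end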

I think the first PR can cover steps 1 and 2. Then work on step 3.

josevalim commented 3 years ago

Working on this because @seanmor5 needs it. :)