elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Support vectorize/devectorize inside gradients #1533

Open jyc opened 1 month ago

jyc commented 1 month ago

Thanks for making Nx!

I tried to use value_and_grad on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor.

defmodule Foo do
  import Nx.Defn
  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

x = ~VEC[0 1] |> vectorize(:bar)
Foo.f_and_grad(x, 1)

This evaluates to:

{#Nx.Tensor<
   vectorized[bar: 2]
   s64
   EXLA.Backend<host:0, 0.731981912.321781778.128426>
   [1, 2]
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.731981912.321781778.128427>
   2.0
 >}

The value is correct and keeps the vectorized axis of the input x, but the gradient surprises me. I would have expected a vectorized gradient with the same :bar axis (2 entries) that is 1 everywhere; instead, it looks like Nx is summing the two gradients.

Is this behavior expected? If so, is there any way to make Nx return a vectorized gradient?

Thanks!
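
For reference, the two readings of the gradient in the example above can be written out by hand. This is only a sketch of the arithmetic, treating each entry of the vectorized x as its own function of the scalar y; the names f_1 and f_2 are introduced here for illustration and are not part of Nx.

```latex
\begin{align*}
f_i(y) &= x_i + y, \qquad \frac{\partial f_i}{\partial y} = 1 \quad (i = 1, 2) \\
\text{expected (per entry):} \quad & \left[ \frac{\partial f_1}{\partial y},\ \frac{\partial f_2}{\partial y} \right] = [1,\ 1] \\
\text{observed (summed):} \quad & \frac{\partial f_1}{\partial y} + \frac{\partial f_2}{\partial y} = 1 + 1 = 2
\end{align*}
```

The observed 2.0 matches the second line, which is consistent with the gradient being accumulated over the vectorized axis rather than kept per entry.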

jyc commented 1 month ago

I know that I can use

y = ~VEC[1 1] |> vectorize(:bar)
Foo.f_and_grad(x, y)

to get the result I expect, but in practice y is actually quite large, so repeating it just so the gradient is computed properly seems wasteful. I will dig into that more though.

josevalim commented 1 month ago

I think this makes sense because the grad is computed over y, but I would like to see if @polvalente has a different opinion.

jyc commented 1 month ago

I tried checking whether it would still be efficient to broadcast y to the size of x in order to get a gradient with the same dimensions as y; I wasn't sure whether Nx would create, e.g., a vector with zero stride. However, it looks like the byte_size increases, at least with Nx.BinaryBackend and EXLA.Backend:

x = ~VEC[0 0] |> vectorize(:foo)
y = ~VEC[1]
[x, y] = Nx.broadcast_vectors([x, y])
y |> Nx.byte_size()
# 16
# if another element is added to `x`, this evaluates to 24, etc.

So I still would be interested if there is a way to get the non-summed gradient, although I understand if it's not possible with this API.
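
The sizes reported above are what a fully materialized broadcast would give. As a sketch, assuming y holds 64-bit integers (the s64 type shown in the earlier output) and one element per vectorized entry:

```latex
\begin{align*}
\text{byte\_size}(y) &= n_{\text{foo}} \times n_{\text{elements}} \times 8 \text{ bytes} \\
n_{\text{foo}} = 2: \quad & 2 \times 1 \times 8 = 16 \\
n_{\text{foo}} = 3: \quad & 3 \times 1 \times 8 = 24
\end{align*}
```

So the storage grows linearly with the size of the vectorized axis, rather than staying at the 8 bytes of the original y.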

polvalente commented 1 month ago

I agree with @jyc in that the grad should have the same vector shape as the output. That is, the correct result for the example should be [1.0, 1.0] instead of 2.0.

The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.
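
A minimal sketch of that mental model for a plain (non-gradient) function, where the equivalence does hold. The function `fun` below is a hypothetical stand-in chosen for illustration, and the `Mix.install` line assumes the snippet is run as a standalone script with network access.

```elixir
# Assumption: run as a script; Mix.install fetches Nx if it is not already available.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical elementwise function standing in for `fun`: x -> 2x + 1.
fun = fn x -> x |> Nx.multiply(2) |> Nx.add(1) end

# Apply `fun` once to a tensor vectorized over :foo.
direct = fun.(Nx.vectorize(Nx.tensor([1, 2]), :foo))

# Apply `fun` to each entry separately, then stack and vectorize the results.
stacked =
  [fun.(Nx.tensor(1)), fun.(Nx.tensor(2))]
  |> Nx.stack()
  |> Nx.vectorize(:foo)

# Devectorize before converting to lists so the two results can be compared directly.
IO.inspect(Nx.to_flat_list(Nx.devectorize(direct)))  # [3, 5]
IO.inspect(Nx.to_flat_list(Nx.devectorize(stacked))) # [3, 5]
```

Both paths agree here; the issue is that the same comparison fails once `fun` is the gradient function.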

polvalente commented 1 month ago

Memory-wise, vectorization will end up doing the explicit broadcasting, if applicable, regardless of the backend (although some backends might end up fusing things).

jyc commented 1 month ago

Note: This specific comment is wrong and can be ignored; what I said earlier & what polvalente said is correct AFAIK. Sorry for the confusion!

@polvalente Thanks for the reply! Sorry but just to be clear, I checked after you mentioned the mental model and it looks like grad returns the same result even without vectorization, so my mentioning the vectorization was a red herring:

defmodule Foo do
  import Nx.Defn
  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

x = ~VEC[0 1 2]
y = ~VEC[1]
Foo.f_and_grad(x, y)
# {~VEC[1 2 3], ~VEC[3.0]}

This is still surprising to me but at least it is consistent with and without vectorization. I will keep looking for a workaround.

jyc commented 1 month ago

Actually, I have confused myself! I don't believe it's a red herring because it's the other axis that is vectorized. I misunderstood. Please ignore my last comment, sorry for the noise. In other words, I agree with your comment here:

The mental model I have is that fun(vectorized[foo: 2] [1, 2]) should yield the same output as vectorize(stack([fun(1), fun(2)]), :foo), which is not the case here.

polvalente commented 1 month ago

The problem here is that, for the Foo module above, this equivalence does not hold:

x = Nx.tensor([0, 1, 2])
y = 1

{_, grad0} = Foo.f_and_grad(x[0], y)
{_, grad1} = Foo.f_and_grad(x[1], y)
{_, grad2} = Foo.f_and_grad(x[2], y)

expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)

actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)

iex(19)> expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
#Nx.Tensor<
  vectorized[foo: 3]
  f32
  [1.0, 1.0, 1.0]
>
iex(21)> actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
{#Nx.Tensor<
   vectorized[foo: 3]
   s32
   [1, 2, 3]
 >,
 #Nx.Tensor<
   f32
   3.0
 >}

jyc commented 1 month ago

You are right! Sorry for the noise.

josevalim commented 1 month ago

Reopening because we still need to support vectorize/devectorize inside the gradient. :)