Open jyc opened 1 month ago
I know that I can use

```elixir
y = ~VEC[1 1] |> vectorize(:bar)
Foo.f_and_grad(x, y)
```

to get the result I expect, but in practice `y` is actually quite large, so repeating it just so the gradient is computed properly seems wasteful. I will dig into that more, though.

I think this makes sense because the grad is computed over `y`, but I would like to see if @polvalente has a different opinion.
I tried checking whether it would still be efficient to broadcast `y` to the size of `x` in order to get a gradient with the same dimensions as `y`; I wasn't sure whether Nx would create, e.g., a vector with zero stride. However, it looks like the `byte_size` increases, at least with `Nx.BinaryBackend` and `EXLA.Backend`:
```elixir
# Assumes `import Nx` (for the ~VEC sigil and vectorize/2).
x = ~VEC[0 0] |> vectorize(:foo)
y = ~VEC[1]
[x, y] = Nx.broadcast_vectors([x, y])
y |> Nx.byte_size()
# 16
# if more elements are added to `x`, this evaluates to 24, etc.
```
So I would still be interested in whether there is a way to get the non-summed gradient, although I understand if it's not possible with this API.
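For reference, the broadcasting workaround discussed above can be sketched as a self-contained script. This is a minimal sketch under stated assumptions: Nx is pulled in with `Mix.install` (the version constraint is illustrative), and `Nx.broadcast_vectors/1` is used to give `y` the same `:foo` axis as `x`. As the `byte_size` measurement above shows, this materializes one copy of `y` per entry, so it does not address the memory concern; whether the gradient then comes back vectorized along `:foo` may also depend on the Nx version.

```elixir
# Assumption: Nx is installed via Mix.install; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule Foo do
  import Nx.Defn

  defn f(x, y), do: x + y

  # Returns {f(x, y), gradient of f with respect to y}.
  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> f(x, y) end)
  end
end

x = Nx.tensor([0, 0]) |> Nx.vectorize(:foo)
y = Nx.tensor(1)

# Explicitly give y the same vectorized :foo axis as x. This copies y once per
# :foo entry, which is the memory cost discussed in the thread.
[x, y] = Nx.broadcast_vectors([x, y])

# According to the thread, with y vectorized the gradient is taken per :foo
# entry instead of being summed over it (behavior may vary across Nx versions).
{value, grad} = Foo.f_and_grad(x, y)
IO.inspect(value, label: "value")
IO.inspect(grad, label: "grad")
```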
I agree with @jyc that the grad should have the same vector shape as the output. That is, the correct result for the example should be `[1.0, 1.0]` instead of `2.0`.

The mental model I have is that `fun(vectorized[foo: 2] [1, 2])` should yield the same output as `vectorize(stack([fun(1), fun(2)]), :foo)`, which is not the case here.
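The right-hand side of this mental model can be computed without any vectorization: apply the function to each entry separately, then stack and re-vectorize. Below is a minimal sketch, assuming two `:foo` entries `x = [0, 0]` and a scalar `y = 1` as illustrative inputs; `MentalModel` and `per_entry_grad/3` are hypothetical names, not part of Nx. Because each call sees a single scalar `x[i]`, each gradient is `1.0`, so the stacked result is `[1.0, 1.0]` along `:foo`.

```elixir
# Assumption: Nx is installed via Mix.install; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule Foo do
  import Nx.Defn

  defn f(x, y), do: x + y

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> f(x, y) end)
  end
end

defmodule MentalModel do
  # Hypothetical helper: computes the gradient for each entry of x on its own
  # (no vectorization), then stacks the results and vectorizes them along
  # `axis`. This is the `vectorize(stack([fun(1), fun(2)]), :foo)` side.
  def per_entry_grad(x, y, axis) do
    0..(Nx.axis_size(x, 0) - 1)
    |> Enum.map(fn i ->
      {_value, grad} = Foo.f_and_grad(x[i], y)
      grad
    end)
    |> Nx.stack()
    |> Nx.vectorize(axis)
  end
end

# Illustrative inputs: two entries along :foo and a scalar y.
x = Nx.tensor([0, 0])
y = Nx.tensor(1)

expected = MentalModel.per_entry_grad(x, y, :foo)
IO.inspect(expected)
# vectorized[foo: 2] f32 [1.0, 1.0]
```

Comparing `expected` with the gradient from `Foo.f_and_grad(Nx.vectorize(x, :foo), y)` is exactly the check that fails in this issue.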
Memory-wise, vectorization will end up doing the explicit broadcasting, if applicable, regardless of the backend (although some backends might end up fusing things).
Note: This specific comment is wrong and can be ignored; what I said earlier & what polvalente said is correct AFAIK. Sorry for the confusion!
@polvalente Thanks for the reply! Sorry, but just to be clear: I checked after you mentioned the mental model, and it looks like `grad` returns the same result even without vectorization, so my mentioning vectorization was a red herring:
```elixir
defmodule Foo do
  import Nx.Defn

  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

# Assumes `import Nx` for the ~VEC sigil.
x = ~VEC[0 1 2]
y = ~VEC[1]
Foo.f_and_grad(x, y)
# {~VEC[1, 2, 3], ~VEC[3]}
```
This is still surprising to me but at least it is consistent with and without vectorization. I will keep looking for a workaround.
Actually, I have confused myself! I don't believe it's a red herring, because it's the other axis that is vectorized. I misunderstood. Please ignore my last comment; sorry for the noise. In other words, I agree with your comment here:

> The mental model I have is that `fun(vectorized[foo: 2] [1, 2])` should yield the same output as `vectorize(stack([fun(1), fun(2)]), :foo)`, which is not the case here.
The problem here is that, for that `Foo` module, this isn't true:
```elixir
x = Nx.tensor([0, 1, 2])
y = 1
{_, grad0} = Foo.f_and_grad(x[0], y)
{_, grad1} = Foo.f_and_grad(x[1], y)
{_, grad2} = Foo.f_and_grad(x[2], y)
expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
```

```
iex(19)> expected_result = Nx.stack([grad0, grad1, grad2]) |> Nx.vectorize(:foo)
#Nx.Tensor<
  vectorized[foo: 3]
  f32
  [1.0, 1.0, 1.0]
>
iex(21)> actual_result = Foo.f_and_grad(Nx.vectorize(x, :foo), y)
{#Nx.Tensor<
   vectorized[foo: 3]
   s32
   [1, 2, 3]
 >,
 #Nx.Tensor<
   f32
   3.0
 >}
```
You are right! Sorry for the noise.
Reopening because we still need to support vectorize/devectorize inside the gradient. :)
Thanks for making Nx!
I tried to use `value_and_grad` on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor. This evaluates to:

The value is correct and keeps the vectorized axis of the vectorized input, `x`, but the gradient surprises me. I would have expected a vectorized tensor (a rank-1, dimension-2 vector) with the same `:foo` axis that is `1` everywhere; instead, it looks like Nx is summing up the two gradients. Is this behavior expected? If so, is there any way to make Nx return a vectorized gradient?
Thanks!