elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0
1.52k stars 100 forks

`Nx.top_k` doesn't seem to support backpropagation. #509

Closed chaoqu closed 10 months ago

chaoqu commented 1 year ago

Background: this issue arises from a slack conversation in the nx channel (https://elixir-lang.slack.com/archives/C01M2B96EF9/p1688418081101709)

The issue seems to be that Nx.top_k does not support backpropagation. For example, this code doesn't work:

model =
  input # <---- this is a 4x8 Tensor
  |> Axon.embedding(65, 65, name: "embedding", kernel_initializer: :zeros) # <--- this should give me a 4x8x65
  |> Axon.sigmoid() # <--- should still be 4x8x65
  |> Axon.nx(fn result ->
    {_r, index} = Nx.top_k(result, k: 1) # <-- i am taking the index of the max value in the "65" dimension. 
    result1 = index |> Nx.squeeze() # <--- I should get a 4x8
    result1
  end)

The error I get is:

** (ArgumentError) invalid padding configuration, rank of padding configuration and shape must match
    (nx 0.5.3) lib/nx/shape.ex:1008: Nx.Shape.padded_dims/3
    (nx 0.5.3) lib/nx/shape.ex:991: Nx.Shape.pad/2
    (nx 0.5.3) lib/nx.ex:3034: Nx.pad/3
    (nx 0.5.3) Elixir.Nx.Defn.Grad.erl:469: Nx.Defn.Grad.grad/4
    (nx 0.5.3) Elixir.Nx.Defn.Grad.erl:343: Nx.Defn.Grad.update_grads/6
    (nx 0.5.3) Elixir.Nx.Defn.Grad.erl:207: Nx.Defn.Grad.recur_to_grad/4
    (nx 0.5.3) Elixir.Nx.Defn.Grad.erl:196: Nx.Defn.Grad.recur_to_grad/4
    iex:11: (file)

However, replacing Nx.top_k with Nx.argmax fixes the issue:

model =
    input # <---- this is a 4x8 Tensor
    |> Axon.embedding(65, 65, name: "embedding", kernel_initializer: :zeros) # <--- this should give me a 4x8x65
    |> Axon.sigmoid() # <--- should still be 4x8x65
    |> Axon.nx(fn result ->
      Nx.argmax(result, axis: -1)
    end)
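For k: 1, the two snippets should compute the same indices, which is why the swap is a viable workaround here. A quick NumPy check of that equivalence (not part of the thread; assumes distinct values along the last axis, since tie-breaking between a descending argsort and argmax need not agree):

```python
import numpy as np

def indices_match(t):
    # top-1 index via descending argsort vs. argmax, both along the last axis
    top1 = np.argsort(-t, axis=-1)[..., :1].squeeze(-1)
    return np.array_equal(top1, np.argmax(t, axis=-1))
```

For example, `indices_match(np.array([[0.1, 0.9, 0.5]]))` is `True`. Note that argmax only sidesteps the gradient bug; neither approach yields a useful gradient through integer indices.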
polvalente commented 1 year ago

I will look into this more later, but I'm almost sure the problem doesn't lie within top_k, because it is basically just a delegating call to argsort + take_along_axis.
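That decomposition can be sketched in NumPy terms (a hypothetical illustration of the idea, not the actual Nx source):

```python
import numpy as np

def top_k(t, k=1):
    # indices that would sort the last axis in descending order, truncated to k
    idx = np.argsort(-t, axis=-1)[..., :k]
    # gather the corresponding values
    values = np.take_along_axis(t, idx, axis=-1)
    return values, idx
```

On this view, differentiating top_k reduces to differentiating argsort and take_along_axis, which is consistent with the stack trace above failing inside Nx.Defn.Grad on a padding step rather than in top_k itself.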

polvalente commented 1 year ago

Could you send the full code for reproducing this? I'm getting a different error with what I thought was the correct reproduction code.

seanmor5 commented 10 months ago

I believe this is an Nx bug, not an Axon bug.