chaoqu closed this issue 10 months ago.
I'll look into this more later, but I'm fairly sure the problem isn't in top_k itself, because it basically just delegates to argsort + take_along_axis.
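The delegation described above can be sketched in plain Nx. This is a minimal illustration under stated assumptions, not the actual Nx source: the module MyTopK is hypothetical, the Nx version constraint is assumed, and it works along the last axis only. It sorts indices in descending order with Nx.argsort, keeps the first k with Nx.slice_along_axis, and gathers the matching values with Nx.take_along_axis.

```elixir
# Assumes network access to fetch Nx; the version constraint is an assumption.
Mix.install([{:nx, "~> 0.5"}])

defmodule MyTopK do
  # Hypothetical helper: approximates top_k as argsort + take_along_axis
  # along the last axis. This is NOT the actual Nx implementation.
  def top_k(tensor, k) do
    # Indices that order the last axis from largest to smallest, truncated to k.
    indices =
      tensor
      |> Nx.argsort(axis: -1, direction: :desc)
      |> Nx.slice_along_axis(0, k, axis: -1)

    # Gather the values at those indices along the same axis.
    values = Nx.take_along_axis(tensor, indices, axis: -1)
    {values, indices}
  end
end

t = Nx.tensor([3, 1, 4, 1, 5])
{values, indices} = MyTopK.top_k(t, 2)
IO.inspect(Nx.to_flat_list(values))   # prints [5, 4]
IO.inspect(Nx.to_flat_list(indices))  # prints [4, 2]
```

Because this version only uses argsort, slice_along_axis and take_along_axis, swapping it in for Nx.top_k in the failing code is one way to check whether the error comes from top_k itself or from the operations it delegates to.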
Could you share the full code to reproduce this? I get a different error with what I assumed was the correct reproduction code.
I believe this is an Nx bug, not an Axon bug.
Background: this issue arises from a Slack conversation in the nx channel (https://elixir-lang.slack.com/archives/C01M2B96EF9/p1688418081101709).

The issue seems to be that Nx.top_k does not support backpropagation. For example, this code doesn't work:

The error I get is:

However, replacing Nx.top_k with Nx.argmax fixes the issue: