elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0

`Axon.Display.as_graph` breaks when updating axon from 0.5 to 0.6 #594

Closed (nickgnd closed this issue 1 month ago)

nickgnd commented 2 months ago

Hey there 👋 First of all, thanks again for all the work and effort you all are putting into the Nx libraries; I'm truly amazed by your results.

While updating the Nx dependencies in one of my Livebooks, I ran into an issue when using `Axon.Display.as_graph`.

Here is the snippet to reproduce it:

Mix.install(
  [
    {:nx, "~> 0.6.1", override: true},
    {:axon, "~> 0.6"},
    {:exla, "~> 0.6.0"},
    {:kino, "~> 0.14"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

batch_size = 4
block_size = 8
shape = {batch_size, block_size}

bigram_model =
  Axon.input("sequence", shape: shape)
  |> Axon.embedding(65, 65)

Axon.Display.as_graph(bigram_model, Nx.template(shape, :f32),
  direction: :top_down
)

It fails with the following error:

** (Axon.CompileError) exception found when compiling layer Axon.Layers.embedding/3 named embedding_0:

    ** (ArgumentError) indices must be an integer tensor, got {:f, 32}
        (nx 0.6.4) lib/nx.ex:14150: Nx.take/3

(pass debug: true to build/compile see where the layer was defined)

Compiling of the model was initiated at:

    (axon 0.6.1) lib/axon.ex:3406: anonymous fn/3 in Axon.get_output_shape/3
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (axon 0.6.1) lib/axon/defn.ex:14: Axon.Defn.__jit__/5
    (nx 0.6.4) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.1) lib/axon.ex:3409: Axon.get_output_shape/3
    #cell:5i2jifureixx5mee:11: (file)

I tried to debug it further, but I couldn't discover much.

I think the `Nx.take/3` call that raises the exception is probably the one in `Axon.Layers.embedding/3`, which in turn is probably called from `Axon.embedding/4`... 🤔

☝️ The same snippet used to work correctly with Axon v0.5.1 (see screenshot):

Mix.install(
  [
    {:nx, "~> 0.6.1", override: true},
    {:axon, "~> 0.5.1"},
    {:exla, "~> 0.6.0"},
    {:kino, "~> 0.14"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
[screenshot]

Let me know if I can help in any way, cheers ✌️

seanmor5 commented 2 months ago

Can you try from the main branch?

nickgnd commented 2 months ago

Hey @seanmor5, I just tried, but no luck. I got the same error:

Mix.install(
  [
    {:nx, "~> 0.6.1", override: true},
    {:axon, git: "git@github.com:elixir-nx/axon.git", ref: "main"},
    {:exla, "~> 0.6.0"},
    {:kino, "~> 0.14"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

# ....

Error:

** (Axon.CompileError) exception found when compiling layer Axon.Layers.embedding/3 named embedding_0:

    ** (ArgumentError) indices must be an integer tensor, got {:f, 32}
        (nx 0.6.4) lib/nx.ex:14150: Nx.take/3

(pass debug: true to build/compile to see where the layer was defined)

Compiling of the model was initiated at:

    (axon 0.6.1) lib/axon.ex:3650: anonymous fn/3 in Axon.get_output_shape/3
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (axon 0.6.1) lib/axon/defn.ex:14: Axon.Defn.__jit__/5
    (nx 0.6.4) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.1) lib/axon.ex:3655: Axon.get_output_shape/3
    #cell:gef2wlvqmz5hlji6:10: (file)

seanmor5 commented 1 month ago

@nickgnd I just looked again, and I think the issue is that you are passing an `:f32` template tensor as input. When using the embedding layer, the inputs need to be of an integer type. Can you change `Nx.template(shape, :f32)` to `Nx.template(shape, :u32)` and let me know if it works?
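For reference, here is a sketch of the reproduction snippet with only that change applied. It assumes the same `Mix.install` setup and model as in the original report; nothing else is altered.

```elixir
# Same dependencies as in the original report
Mix.install(
  [
    {:nx, "~> 0.6.1", override: true},
    {:axon, "~> 0.6"},
    {:exla, "~> 0.6.0"},
    {:kino, "~> 0.14"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

batch_size = 4
block_size = 8
shape = {batch_size, block_size}

bigram_model =
  Axon.input("sequence", shape: shape)
  |> Axon.embedding(65, 65)

# The template now has an integer type (:u32 instead of :f32),
# because the embedding layer treats its input as indices.
template = Nx.template(shape, :u32)

Axon.Display.as_graph(bigram_model, template, direction: :top_down)
```

Any integer type should satisfy the check reported in the error; `:u32` is simply the one suggested here.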

nickgnd commented 1 month ago

@seanmor5 you're totally right and it makes a lot of sense! It works indeed 🎉

[screenshot]

And sorry, I should pay more attention to the error message 🙈
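
The type requirement named in the error message can also be checked in isolation. The following is a minimal sketch, assuming (as the stack trace suggests) that the embedding lookup goes through `Nx.take/3`. The `kernel` tensor is a made-up stand-in for the embedding weights, not Axon's actual parameters.

```elixir
Mix.install([{:nx, "~> 0.6.1"}])

# Hypothetical stand-in for a 65x65 embedding kernel (values are arbitrary)
kernel = Nx.iota({65, 65}, type: :f32)

# Integer indices are accepted: one kernel row is gathered per index
int_indices = Nx.tensor([[0, 1], [2, 3]], type: :u32)
taken = Nx.take(kernel, int_indices, axis: 0)
IO.inspect(Nx.shape(taken), label: "shape with :u32 indices")

# The same indices cast to float are rejected with an ArgumentError
float_indices = Nx.as_type(int_indices, :f32)

message =
  try do
    Nx.take(kernel, float_indices, axis: 0)
    :no_error
  rescue
    e in ArgumentError -> Exception.message(e)
  end

IO.inspect(message, label: "result with :f32 indices")
```

This is why the type of the template passed to `Axon.Display.as_graph` matters: the template's type is what the embedding layer sees as its index type when the model is traced.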