mortont / axon_onnx

Easily convert models between ONNX and Axon
Apache License 2.0

Cannot import BERT #46

Closed: dmorn closed 1 year ago

dmorn commented 1 year ago

Version 8f192d65143385621686758cb3351ad32b0ab8e2 was able to import BERT models with a call like

AxonOnnx.import("BERT.onnx", sequence: 512, batch: 1)

The master branch now fails with:

** (FunctionClauseError) no function clause matching in anonymous fn/2 in Axon.split_inputs/2

    The following arguments were given to anonymous fn/2 in Axon.split_inputs/2:

        # 1
        #Nx.Tensor<
          s64[1]
          [12]
        >

        # 2
        %{
          62 => %Axon.Node{
            id: 62,
            name: #Function<89.114997104/2 in Axon.unique_identifiers/2>,
            parent: [],
            parameters: [],
            args: [],
            op: :constant,
            policy: p=f32 c=f32 o=f32,
            hooks: [],
            opts: [
              value: #Nx.Tensor<
                s64[1]
                [1]
              >
            ],
            op_name: :constant,
            stacktrace: [
              {Axon, :layer, 3, [file: 'lib/axon.ex', line: 272]},
              {AxonOnnx.Deserialize, :recur_nodes, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 1484]},
              {Enum, :"-reduce/3-lists^foldl/2-0-", 3,
               [file: 'lib/enum.ex', line: 2468]},
              {AxonOnnx.Deserialize, :graph_to_axon, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 44]},
              {AxonOnnx.Deserialize, :to_axon, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 27]},
              {:elixir, :"-eval_external_handler/1-fun-2-", 4,
               [file: 'src/elixir.erl', line: 298]}
            ]
          },
          63 => %Axon.Node{
            id: 63,
            name: #Function<89.114997104/2 in Axon.unique_identifiers/2>,
            parent: [],
            parameters: [],
            args: [],
            op: :constant,
            policy: p=f32 c=f32 o=f32,
            hooks: [],
            opts: [
              value: #Nx.Tensor<
                s64[1]
                [512]
              >
            ],
            op_name: :constant,
            stacktrace: [
              {Axon, :layer, 3, [file: 'lib/axon.ex', line: 272]},
              {AxonOnnx.Deserialize, :recur_nodes, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 1484]},
              {Enum, :"-reduce/3-lists^foldl/2-0-", 3,
               [file: 'lib/enum.ex', line: 2468]},
              {AxonOnnx.Deserialize, :graph_to_axon, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 44]},
              {AxonOnnx.Deserialize, :to_axon, 2,
               [file: 'lib/axon_onnx/deserialize.ex', line: 27]},
              {:elixir, :"-eval_external_handler/1-fun-2-", 4,
               [file: 'src/elixir.erl', line: 298]}
            ]
          }
        }

    (axon 0.2.0) lib/axon.ex:298: anonymous fn/2 in Axon.split_inputs/2
    (elixir 1.14.0) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (elixir 1.14.0) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.4.0) lib/nx/container.ex:77: Nx.Container.Tuple.traverse/3
    (axon 0.2.0) lib/axon.ex:298: Axon.split_inputs/2
    (axon 0.2.0) lib/axon.ex:261: Axon.layer/3
    (axon 0.2.0) lib/axon.ex:1962: Axon.concatenate/2

I bisected the tree and apparently

229448d8deb2dc53c7cb733095bc1ab14b4121ef is the first bad commit
commit 229448d8deb2dc53c7cb733095bc1ab14b4121ef
Author: Sean Moriarity <smoriarity.5@gmail.com>
Date:   Wed Sep 7 06:55:35 2022 -0400

    Remove ignore_batch logic

 lib/axon_onnx/deserialize.ex | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

even though it fails with a different error:

** (ArgumentError) cannot reshape, current shape {1, 1, 512, 12, 64} is not compatible with new shape {1, 1, 1, 768}
    (nx 0.3.0) lib/nx/shape.ex:184: Nx.Shape.reshape/2
    (nx 0.3.0) lib/nx.ex:2281: Nx.reshape/3
    (axon 0.2.0) lib/axon/compiler.ex:537: Axon.Compiler.layer_predict_fun/14
    (axon 0.2.0) lib/axon/compiler.ex:614: Axon.Compiler.layer_init_fun/9
    (axon 0.2.0) lib/axon/compiler.ex:132: Axon.Compiler.call_init_cache/5
    (axon 0.2.0) lib/axon/compiler.ex:591: anonymous fn/4 in Axon.Compiler.layer_init_fun/9
    (elixir 1.14.0) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    iex:1: (file)

seanmor5 commented 1 year ago

With current main I am able to run:

python3 -m transformers.onnx --model=bert-base-cased models/

And successfully import with:

AxonOnnx.import("models/model.onnx")

Can you confirm this?

dmorn commented 1 year ago

Confirmed! 🤩 Thanks!

snewcomer commented 1 year ago

"even though it fails with a different error:"

Was this still a problem you encountered, @dmorn? Running with the latest tags for the various dependencies from the Semantic Search blog post, I get this:

cannot reshape, current shape {2, 120, 768} is not compatible with new shape {128, 120, 12, 64}

Any thoughts on how to solve this?
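
(Editor's sketch, not from any commenter.) Both reshape errors in this thread come down to the same rule: a reshape is only valid when the source and target shapes hold the same number of elements. The snippet below checks that for the two shape pairs quoted above using plain Elixir (`Tuple.product/1`, available since Elixir 1.12). `ReshapeCheck` is a hypothetical helper written for illustration; it is not part of Nx, Axon, or AxonOnnx.

```elixir
defmodule ReshapeCheck do
  # Hypothetical helper for illustration only.
  # A reshape is valid only if both shapes contain the same number of elements.
  def compatible?(from, to), do: Tuple.product(from) == Tuple.product(to)
end

# Shape pairs copied from the two error messages in this thread.
pairs = [
  {{1, 1, 512, 12, 64}, {1, 1, 1, 768}},
  {{2, 120, 768}, {128, 120, 12, 64}}
]

for {from, to} <- pairs do
  IO.puts(
    "#{inspect(from)} (#{Tuple.product(from)} elements) -> " <>
      "#{inspect(to)} (#{Tuple.product(to)} elements): " <>
      "compatible? #{ReshapeCheck.compatible?(from, to)}"
  )
end
```

The first pair holds 393,216 versus 768 elements (a factor of 512, the sequence length passed at import), and the second holds 184,320 versus 11,796,480 (a factor of 64, i.e. a batch of 2 versus 128). That pattern might indicate a sequence or batch dimension fixed as a constant inside the imported graph, but this is a guess that the thread does not confirm.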

dmorn commented 1 year ago

Hey @snewcomer! If you're playing with transformer models, I suggest you check out https://github.com/elixir-nx/bumblebee. I stopped pursuing the onnx->axon approach because, at the time I opened this issue, it produced a model that could not be serialized afterwards. I did not dig into this any further, though.