mortont / axon_onnx

Easily convert models between ONNX and Axon
Apache License 2.0

Add BERT support #17

Closed: dmorn closed this issue 2 years ago

dmorn commented 2 years ago

My inspection is based on the ONNX export of Hugging Face's bert-base-german-cased model, created with the python -m transformers.onnx --model=bert-base-german-cased onnx/ command as described in https://huggingface.co/docs/transformers/serialization.

After decoding the model from the protobuf file, I tried the conversion:

iex(6)> AxonOnnx.Deserialize.graph_to_axon(g, %{"batch" => 32, "sequence" => 512})
** (RuntimeError) unsupported Slice op in graph
    (axon_onnx 0.1.0) lib/axon_onnx/deserialize.ex:419: anonymous fn/3 in AxonOnnx.Deserialize.get_nodes/4
    (elixir 1.13.2) lib/enum.ex:2396: Enum."-reduce/3-lists^foldl/2-0-"/3
    (axon_onnx 0.1.0) lib/axon_onnx/deserialize.ex:64: AxonOnnx.Deserialize.graph_to_axon/2

I checked whether any other operations are missing by comparing the graph's op types with the ones handled in get_nodes:

g.node |> Enum.map(fn %{op_type: op} -> op end) |> Enum.uniq
["Unsqueeze", "Cast", "Constant", "Sub", "Mul", "Shape", "Gather", "Add",
 "Slice", "ReduceMean", "Pow", "Sqrt", "Div", "MatMul", "Concat", "Reshape",
 "Transpose", "Softmax", "Erf", "Gemm", "Tanh"]

Slice appears to be the only missing operation.
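The same check can be phrased as a set difference, so it reports the missing ops directly instead of requiring a manual comparison. A minimal sketch; the module name is hypothetical, and `supported` stands for the op types that get_nodes handles, which you would have to list yourself (it is not read from AxonOnnx).

```elixir
defmodule OpCoverageSketch do
  # Hypothetical helper: returns the op types that appear in a graph but are
  # not in `supported`, de-duplicated and sorted.
  def missing(op_types, supported) do
    op_types
    |> MapSet.new()
    |> MapSet.difference(MapSet.new(supported))
    |> Enum.sort()
  end
end
```

With a decoded graph this would be used as `g.node |> Enum.map(& &1.op_type) |> OpCoverageSketch.missing(supported)`.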

dmorn commented 2 years ago

I'm checking whether I can implement it myself!

dmorn commented 2 years ago

Here is an example Slice node from the model:

%Onnx.NodeProto{
  __uf__: [],
  attribute: [],
  doc_string: "",
  domain: "",
  input: ["embeddings.position_ids", "1610", "218", "1611", "220"],
  name: "Slice_14",
  op_type: "Slice",
  output: ["221"]
}
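Per the ONNX Slice spec, the node's five inputs are, in order, data, starts, ends, axes and steps. Negative starts/ends count from the end of the axis, out-of-range values (the spec recommends INT_MAX/INT_MIN for "slice to the end") are clamped, and ends is exclusive. A minimal sketch of those per-axis rules in plain Elixir, returning a {start, count} pair that fits a start/length slicing API; the module name is hypothetical, and negative axes (which add the tensor rank) are not handled.

```elixir
defmodule OnnxSliceSketch do
  @moduledoc """
  Hypothetical helper modelling the ONNX Slice index rules for one axis.
  """

  @doc """
  Normalizes `start`/`stop` for an axis of size `dim` and returns
  `{start, count}`, where `count` is the number of selected elements.
  """
  def normalize(dim, start, stop, step \\ 1) when is_integer(step) and step != 0 do
    # Negative indices count from the end of the axis.
    start = if start < 0, do: start + dim, else: start
    stop = if stop < 0, do: stop + dim, else: stop

    # Clamp to the valid range; the bounds differ for negative steps.
    {start, stop} =
      if step > 0 do
        {clamp(start, 0, dim), clamp(stop, 0, dim)}
      else
        {clamp(start, 0, dim - 1), clamp(stop, -1, dim - 1)}
      end

    # `stop` is exclusive, so count = ceil((stop - start) / step), floored at 0.
    {start, max(0, ceil_div(stop - start, step))}
  end

  defp clamp(x, lo, hi), do: x |> max(lo) |> min(hi)

  # ceil(a / b) for integers, valid for negative a or b.
  defp ceil_div(a, b), do: -Integer.floor_div(-a, b)
end
```

For example, `OnnxSliceSketch.normalize(10, -3, 9_223_372_036_854_775_807)` gives `{7, 3}`. With step 1 the count is `ends - starts`, not `ends - starts + 1`.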
dmorn commented 2 years ago

I'm stuck on https://github.com/elixir-nx/nx/issues/511 for now, so I can't run the tests.

dmorn commented 2 years ago

@seanmor5 I can now run the tests, so I'm able to contribute. How would you proceed? Do you have any feedback in this regard?

dmorn commented 2 years ago

@seanmor5 I figured out how the tests are generated and added a failing Slice test. Which Axon operation would be a good candidate for implementing Slice?

josevalim commented 2 years ago

@dmorn Axon has an Nx layer where you can invoke any Nx function. Maybe that's the way to go here?

seanmor5 commented 2 years ago

I think Slice will be harder because the slice sizes can be runtime values, which Nx does not support right now. You can still try, though, and use the constant! helper in deserialize to force the sizes to be constant.
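One way to read "force the sizes to be constant": resolve the slice bounds to plain integers while the graph is being converted, and capture them in the layer function, so that function only receives the data and never has to read index values at run time. A minimal sketch in plain Elixir with a list standing in for a tensor; the module and function names are hypothetical and not part of AxonOnnx. In an Nx-based version the closure body would call something like Nx.slice_axis instead of Enum.slice.

```elixir
defmodule ConstantSliceSketch do
  # Hypothetical: `start` and `len` were resolved to plain integers during
  # conversion (e.g. read from a Constant node), so the returned layer
  # function takes only the data and never inspects index tensors.
  def build_layer_fn(start, len) when is_integer(start) and is_integer(len) do
    fn data -> Enum.slice(data, start, len) end
  end
end
```

For instance, `ConstantSliceSketch.build_layer_fn(1, 2).([10, 20, 30, 40])` returns `[20, 30]`.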

dmorn commented 2 years ago

@josevalim and @seanmor5, thanks for joining! Here is what I'm doing:

  # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#slice
  defp to_axon_slice(%Node{op_type: "Slice"} = node, axon, _params, used_params) do
    parent = Enum.map(node.input, &axon!(&1, axon))
    [output_name] = node.output

    # Inputs, in ONNX order: data, starts, ends, axes, steps
    op = fn data, starts, ends, axes, _steps ->
      [axes, starts, ends] # TODO(dmorn): include steps
      |> Enum.map(&Nx.to_flat_list/1)
      |> Enum.zip()
      |> Enum.reduce(data, fn {axis, start, stop}, acc ->
        # ONNX `ends` is exclusive, so the slice length is stop - start
        Nx.slice_axis(acc, start, stop - start, axis)
      end)
    end

    # At this point I don't know the output shape yet
    layer = Axon.layer(parent, op, {}, %{}, output_name)
    updated_axon = Map.put(axon, output_name, layer)
    {updated_axon, used_params}
  end

@seanmor5, are you referring to the output shape when you mention constant!? What does it mean to leave out the output shape, as I'm doing?

dmorn commented 2 years ago

That was the idea, but I bet I can't use just anything inside op, right? 😆

** (Axon.CompilerError) error while building prediction for #Function<17.30279046/5 in AxonOnnx.Deserialize.to_axon_slice/4> layer with name y:

     ** (ArgumentError) cannot invoke to_binary/2 on Nx.Defn.Expr.

     This typically means you are invoking an unsupported Nx function
     inside `defn` or inside JIT compiled code
dmorn commented 2 years ago

OK, but I bet I can translate this into Nx-only operations. Is the big picture correct, though?

seanmor5 commented 2 years ago

@dmorn I think you're going down the right path! The problem is that ONNX is more flexible with shapes, whereas Nx is not, but I think we can still make this work.

dmorn commented 2 years ago

Is there any way to convert a tensor to a list within op?

josevalim commented 2 years ago

A tensor of what to a list of what? :) Do you have an example?

Btw, we do have a Slack channel where we hang out, in case you want to join. You just need to create an account on the ErlEF website (it is free), and once you log in you can request an invite from your settings page. We are all in the #machine-learning channel.

dmorn commented 2 years ago

To use as arguments to Nx.slice/4.

josevalim commented 2 years ago

If you have a tensor t of type s64[4] and you want to use its values as the index positions, you can do this:

Nx.slice(other_tensor, [t[0], t[1], t[2], t[3]], [2, 2, 2, 2])

The lengths cannot be dynamic though.
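The split above (start indices may come from a tensor at run time, lengths must be fixed up front) matches how XLA's DynamicSlice works: the start index is clamped so the fixed-size window stays inside the operand. A minimal sketch of that behaviour on a plain list; the module name is hypothetical, and whether Nx.slice clamps out-of-range starts the same way on every backend is an assumption worth checking in the Nx docs.

```elixir
defmodule DynamicStartSketch do
  # Models dynamic-slice semantics on a list: `len` is fixed when the function
  # is built, `start` arrives at run time and is clamped to [0, size - len] so
  # the window always fits inside the data.
  def build(len) when is_integer(len) and len >= 0 do
    fn data, start ->
      size = length(data)

      if len > size do
        raise ArgumentError, "slice length #{len} exceeds size #{size}"
      end

      start = start |> max(0) |> min(size - len)
      Enum.slice(data, start, len)
    end
  end
end
```

Note that a start of 3 with length 2 on a 4-element list is shifted back to 2 rather than producing a shorter result, which differs from ONNX Slice, where ends beyond the axis simply shorten the output.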

josevalim commented 2 years ago

There are other functions like Nx.gather, Nx.take_along_axis and so on that may be helpful here.
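A strided Slice (the "include steps" TODO above) can be expressed as an explicit list of indices, which is the form that gather/take-style functions consume. A minimal sketch that computes those indices in plain Elixir and applies them to a list; the module name is hypothetical. With Nx, the index list could be wrapped in a tensor and passed to Nx.take/3 with the `axis:` option, assuming start and count were resolved at conversion time.

```elixir
defmodule StridedSliceSketch do
  # Indices visited by a strided slice: `count` elements starting at `start`,
  # moving by `step` (which may be negative).
  def indices(start, count, step) when step != 0 and count >= 0 do
    Enum.map(0..(count - 1)//1, fn i -> start + i * step end)
  end

  # Applies the indices to a plain list standing in for a 1-D tensor.
  def take(data, start, count, step) do
    Enum.map(indices(start, count, step), &Enum.at(data, &1))
  end
end
```

Combined with the ONNX clamping rules, `start` and `count` would be the normalized values for the axis, so every generated index is in bounds.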

dmorn commented 2 years ago

I'll do it asap, thanks @josevalim 😊

seanmor5 commented 2 years ago

See #22