sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Translating node #14 "/Add_1" Add ToTypedTranslator #1202

Closed igor-yusupov closed 1 year ago

igor-yusupov commented 1 year ago

I tried to run the model, but tract fails to load it.

use rand::prelude::*;
use tract_onnx::{
    prelude::*,
    tract_hir::tract_ndarray::{Array, Array2, Array3, Array4, Dim},
};

fn main() {
    let decoder = tract_onnx::onnx()
        .model_for_path("weights/decoder_base_fix_kv_cache.onnx")
        .unwrap()
        .with_output_fact(0, InferenceFact::default())
        .unwrap()
        .into_optimized()
        .unwrap()
        .into_runnable()
        .unwrap();

    let mut rng = thread_rng();

    let tokens: Array2<i64> = Array2::from_shape_vec((1, 4), vec![1, 2, 3, 4]).unwrap();
    let tokens: Tensor = tokens.into();
    let shape = Dim([1, 1500, 512]);
    let audio_features: Array3<f32> = Array3::from_shape_fn(shape, |_| rng.gen());
    let audio_features: Tensor = audio_features.into();
    let shape = Dim([12, 1, 451, 512]);
    let kv_cache: Array4<f32> = Array4::from_shape_fn(shape, |_| rng.gen());
    let kv_cache: Tensor = kv_cache.into();
    let offset: Tensor = Array::from_elem((), 0i64).into();
    let inputs = tvec!(
        tokens.into(),
        audio_features.into(),
        kv_cache.into(),
        offset.into()
    );

    let output = decoder.run(inputs).unwrap();
    println!("{:?}", output.inline_size());
}

I'm using the version from the main branch and get this error:

Caused by:
    0: in output_facts invocation for /Add_1
    1: Can not broadcast shapes a:tokens_dynamic_axes_1,tokens_dynamic_axes_2,512,F32 b:1,<Sym0>,512,F32', src/main.rs:14:10
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

If I use version 0.20.18 instead, I get a different error:

Caused by:
    0: Infering facts
    1: Applying rule outputs[0].shape == inputs[0].shape
    2: Unifying shapes 12,ScatterNDoutput_kv_cache_dim_1,451,512 and 12,kv_cache_dynamic_axes_1,451,512
    3: Impossible to unify Sym(ScatterNDoutput_kv_cache_dim_1) with Sym(kv_cache_dynamic_axes_1).', src/main.rs:14:10
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

weights link: https://www.dropbox.com/scl/fi/sl4rajtlsy0sfmpavkbtl/decoder_base_fix_kv_cache.onnx?rlkey=q2i2ju81t6vb7twwimf47ct1u&dl=0

Without calling into_optimized(), loading the weights works, but inference runs indefinitely :)

I would be grateful if you could tell me what I need to fix.

kali commented 1 year ago

Hey, what happened here? I was actually having a look at your model... and it's a pretty challenging one. There are multiple places where tract could handle things better, and I have actually made a couple of fixes.

Did you discover something wrong with the model in the meantime that I need to be aware of?

igor-yusupov commented 1 year ago

@kali, I found a couple of ways to change the model so that it can run with tract, and decided to try to finish it myself first so as not to distract you :)

I will now upload new weights and explain why the model fails to load in tract. Please wait a few minutes.

igor-yusupov commented 1 year ago

@kali new weights link: https://drive.google.com/file/d/1xDFw9Q4-CIeKgdMjo_88jx7rzZmPBFGt/view?usp=sharing

It currently raises this error:

Caused by:
    0: in output_facts invocation for /decoder/Add_1
    1: Can not broadcast shapes a:batch_size,token_len,512,F32 b:1,<Sym0>,512,F32', src/main.rs:16:10
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

The error is raised at the place where the slice is taken.

The Python code looks like this:

x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]

offset is a parameter that is passed as a model input during inference.
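To make the shape question concrete, here is a minimal pure-Python sketch of that slice, with nested lists standing in for tensors. The 448 x 512 table size is taken from the tract dump of /decoder/Slice in this thread; the helper name `positional_slice` is hypothetical and not part of the model code.

```python
def positional_slice(positional_embedding, offset, token_len):
    """Return rows offset .. offset + token_len of the positional table."""
    return positional_embedding[offset : offset + token_len]


# Stand-in for the 448 x 512 positional embedding table.
table = [[0.0] * 512 for _ in range(448)]

rows = positional_slice(table, offset=3, token_len=4)
print(len(rows), len(rows[0]))  # 4 512
```

As long as the slice stays inside the table, the number of rows equals token_len whatever the offset is. An inference engine has to establish that fact from the graph alone, before any concrete offset is known.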

The model also fails with an error when loading:

thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Translating node #54 "/decoder/blocks.0/attn/Range" Range ToTypedTranslator

Caused by:
    Range needs fixed inputs', src/main.rs:16:10

This place also uses slices.

igor-yusupov commented 1 year ago

The previous weights are also correct, though, so you can use them as a reference.

igor-yusupov commented 1 year ago

@kali Is it possible to run this model with tract? I will be glad to help. If there are any questions about the model layers, I am ready to answer.

kali commented 1 year ago

One thing that makes this model difficult is the way it kind of does a manual gather at the very beginning:

NUC 23/10 13:22 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml --  --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long
    Finished dev [unoptimized + debuginfo] target(s) in 0.07s
     Running `/home/kali/dev/sonos/tract/target/debug/tract --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long`
┏ 1 Source tokens
┃   * output fact #0: batch_size,token_len,I64 >2/1 >5/0 MODEL INPUT #0
┣┓
┃┣┻ 2 Gather /decoder/token_embedding/Gather
┃┃   * input fact  #0: 0/0> 51865,512,F32 -0.008529663, 0.020874023, -0.015792847, 0.022338867, -0.0206604, 0.018005371, -0.014289856, 0.0024528503, 0.015136719, 0.002670288, 0.021194458, -0.011421204...
┃┃     input fact  #1: 1/0> batch_size,token_len,I64
┃┃   * output fact #0: batch_size,token_len,512,F32 >14/0  /decoder/token_embedding/Gather_output_0
┃┃┏ 3 Source offset
┃┃┃   * output fact #0: I64 >4/0 >8/0 MODEL INPUT #3
┃┃┣┓
┃┃┃┣ 4 AddDims /decoder/Unsqueeze
┃┃┃┃   * input fact  #0: 3/0> I64
┃┃┃┃   * output fact #0: 1,I64 >13/1  /decoder/Unsqueeze_output_0
┗━━┓
┃┃┃┣ 5 Shape /decoder/Shape
┃┃┃┃   * input fact  #0: 1/0> batch_size,token_len,I64
┃┃┃┃   * output fact #0: 2,TDim batch_size, token_len >7/0  /decoder/Shape_output_0
┃┃┃┣┻ 7 Gather /decoder/Gather
┃┃┃┃   * input fact  #0: 5/0> 2,TDim batch_size, token_len
┃┃┃┃     input fact  #1: 6/0> ,I64 1
┃┃┃┃   * output fact #0: ,TDim token_len >8/1  /decoder/Gather_output_0
┃┗┓┃
┃┃┣┻ 8 Add /decoder/Add
┃┃┃   * input fact  #0: 3/0> I64
┃┃┃     input fact  #1: 7/0> ,TDim token_len
┃┃┃   * output fact #0: TDim >9/0  /decoder/Add_output_0
┃┃┣ 9 AddDims /decoder/Unsqueeze_1
┃┃┃   * input fact  #0: 8/0> TDim
┃┃┃   * output fact #0: 1,TDim >13/2  /decoder/Unsqueeze_1_output_0

┃
┃┣┻┻┻┻ 13 StridedSlice /decoder/Slice
┃┃   * input fact  #0: 10/0> 448,512,F32 0.004878998, -0.0050811768, 0.004398346, 0.0033092499, -0.0034942627, 0.00617218, 0.010375977, -0.018112183, -0.007205963, 0.009048462, -0.00023424625, 0.0020427704...
┃┃     input fact  #1: 4/0> 1,I64
┃┃     input fact  #2: 9/0> 1,TDim
┃┃     input fact  #3: 11/0> 1,I64 0
┃┃     input fact  #4: 12/0> 1,I64 1
┃┃   * output fact #0: ?,512,F32 >14/1  /decoder/Slice_output_0
┣┻ 14 Add /decoder/Add_1
    * input fact  #0: 2/0> batch_size,token_len,512,F32
      input fact  #1: 13/0> ?,512,F32
    * output fact #0: ..,F32  MODEL OUTPUT #0 /decoder/Add_1_output_0

The end input of /decoder/Slice (input fact #2) is an uncomputed TDim, so tract cannot compute a meaningful output shape for this StridedSlice. If we backtrack it, it comes from node 3 through node 8 and node 9. Node 3 is a scalar input (offset), so we can try to give it a TDim value on which tract will be able to operate. With this (and a couple of patches to the command line) we get:

NUC 23/10 13:53 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml --  -v --constantize offset -i offset:TDim=Offset --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long
    Finished dev [unoptimized + debuginfo] target(s) in 0.09s
     Running `/home/kali/dev/sonos/tract/target/debug/tract -v --constantize offset -i 'offset:TDim=Offset' --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long`
[2023-10-23T11:54:01.790809955Z INFO  tract] Resource usage init: vsz:66998272 rsz:11763712 rszmax:49049600
[2023-10-23T11:54:01.791277650Z INFO  tract] Resource usage loaded framework (onnx): vsz:66998272 rsz:20283392 rszmax:49049600
[2023-10-23T11:54:01.903374841Z INFO  tract] Resource usage proto model loaded: vsz:385937408 rsz:340602880 rszmax:655974400
[2023-10-23T11:54:02.323843040Z INFO  tract::params] Model Fs("decoder.onnx") loaded
[2023-10-23T11:54:02.323911908Z INFO  tract] Resource usage model loaded: vsz:391233536 rsz:349589504 rszmax:666918912
[2023-10-23T11:54:02.326423648Z INFO  tract::params] Commuting #3 "offset" Source, fact:I64 into a const of ,TDim Offset
[2023-10-23T11:54:02.332996880Z INFO  tract::params] Will stop at analyse
[2023-10-23T11:54:02.333007671Z INFO  tract::params] Running 'analyse'
[2023-10-23T11:54:02.733644983Z INFO  tract] Resource usage after analyse: vsz:178794496 rsz:139341824 rszmax:666918912
[2023-10-23T11:54:02.733670757Z INFO  tract::params] Model ready
[2023-10-23T11:54:02.733676664Z INFO  tract] Resource usage model ready: vsz:178794496 rsz:139341824 rszmax:666918912
┏ 1 Source tokens
┃   * output fact #0: batch_size,token_len,I64 >2/1 >5/0 MODEL INPUT #0
┣┓
┃┣┻ 2 Gather /decoder/token_embedding/Gather
┃┃   * input fact  #0: 0/0> 51865,512,F32 -0.008529663, 0.020874023, -0.015792847, 0.022338867, -0.0206604, 0.018005371, -0.014289856, 0.0024528503, 0.015136719, 0.002670288, 0.021194458, -0.011421204...
┃┃     input fact  #1: 1/0> batch_size,token_len,I64
┃┃   * output fact #0: batch_size,token_len,512,F32 >14/0  /decoder/token_embedding/Gather_output_0
┃┃┣ 4 AddDims /decoder/Unsqueeze
┃┃┃   * input fact  #0: 3/0> ,TDim Offset
┃┃┃   * output fact #0: 1,TDim Offset >13/1  /decoder/Unsqueeze_output_0
┗━┓
┃┃┣ 5 Shape /decoder/Shape
┃┃┃   * input fact  #0: 1/0> batch_size,token_len,I64
┃┃┃   * output fact #0: 2,TDim batch_size, token_len >7/0  /decoder/Shape_output_0
┃┃┣┻ 7 Gather /decoder/Gather
┃┃┃   * input fact  #0: 5/0> 2,TDim batch_size, token_len
┃┃┃     input fact  #1: 6/0> ,I64 1
┃┃┃   * output fact #0: ,TDim token_len >8/1  /decoder/Gather_output_0
┃┃┣┻ 8 Add /decoder/Add
┃┃┃   * input fact  #0: 3/0> ,TDim Offset
┃┃┃     input fact  #1: 7/0> ,TDim token_len
┃┃┃   * output fact #0: ,TDim Offset+token_len >9/0  /decoder/Add_output_0
┃┃┣ 9 AddDims /decoder/Unsqueeze_1
┃┃┃   * input fact  #0: 8/0> ,TDim Offset+token_len
┃┃┃   * output fact #0: 1,TDim Offset+token_len >13/2  /decoder/Unsqueeze_1_output_0

┃
┃┣┻┻┻┻ 13 StridedSlice /decoder/Slice
┃┃   * input fact  #0: 10/0> 448,512,F32 0.004878998, -0.0050811768, 0.004398346, 0.0033092499, -0.0034942627, 0.00617218, 0.010375977, -0.018112183, -0.007205963, 0.009048462, -0.00023424625, 0.0020427704...
┃┃     input fact  #1: 4/0> 1,TDim Offset
┃┃     input fact  #2: 9/0> 1,TDim Offset+token_len
┃┃     input fact  #3: 11/0> 1,I64 0
┃┃     input fact  #4: 12/0> 1,I64 1
┃┃   * output fact #0: token_len,512,F32 >14/1  /decoder/Slice_output_0
┣┻ 14 Add /decoder/Add_1
    * input fact  #0: 2/0> batch_size,token_len,512,F32
      input fact  #1: 13/0> token_len,512,F32
    * output fact #0: batch_size,token_len,512,F32  MODEL OUTPUT #0 /decoder/Add_1_output_0
[2023-10-23T11:54:02.734734055Z INFO  tract] Resource usage done: vsz:71651328 rsz:34631680 rszmax:666918912
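Why the slice output goes from ? to token_len once offset is a symbol can be sketched with a toy linear-expression type. This is only an illustration and not tract's TDim implementation: expressions are plain dicts mapping a symbol name to its coefficient, `None` stands for an unknown value, and the helper names are made up for this example.

```python
def add(a, b):
    """Add two symbolic expressions; an unknown operand gives an unknown result."""
    if a is None or b is None:
        return None
    out = dict(a)
    for name, coeff in b.items():
        out[name] = out.get(name, 0) + coeff
    # Drop terms that cancelled out.
    return {name: coeff for name, coeff in out.items() if coeff != 0}


def neg(a):
    """Negate a symbolic expression (unknown stays unknown)."""
    return None if a is None else {name: -coeff for name, coeff in a.items()}


def slice_len(begin, end):
    """Length of a stride-1 slice: end - begin."""
    return add(end, neg(begin))


offset = {"Offset": 1}
token_len = {"token_len": 1}

# offset left as an opaque runtime input: the length is unknown ("?").
print(slice_len(None, None))  # None

# offset replaced by the symbol Offset: the Offset terms cancel.
print(slice_len(offset, add(offset, token_len)))  # {'token_len': 1}
```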

But we're not there yet. More to come.

kali commented 1 year ago

So after more fixes to the command line, I get this far:

┣┻ 276 Add /decoder/blocks.0/attn/Add_10
    * input fact  #0: 264/0> batch_size,8,token_len,Offset+token_len,F32
      input fact  #1: 275/0> token_len,token_len,F32
    * output fact #0: ..,?  MODEL OUTPUT #0 /decoder/blocks.0/attn/Add_10_output_0
[2023-10-23T12:16:59.275900759Z ERROR tract] Error at stage analyse

    Caused by:
        0: Failed analyse for node #276 "/decoder/blocks.0/attn/Add_10" Add
        1: Failed analyse for node #276 "/decoder/blocks.0/attn/Add_10" Add
        2: Infering facts
        3: Applying rule WithRule { inputs[1].shape }
        4: Matching batch_size,8,token_len,Offset+token_len and token_len,token_len with numpy/onnx broadcast rules
        5: Invalid shape (broadcasting): Sym(token_len) is not compatible with Some(Add([Sym(Offset), Sym(token_len)])).

We are trying to add a tensor of shape batch_size,8,token_len,Offset+token_len and one of token_len,token_len. This can't work, unless Offset, the symbol I introduced previously, is zero, which is what your code sample does (setting it to 0 instead of a symbol). I think there is something bad in there already, because I assume this is meant to work with any value of Offset.
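The broadcasting check that fails here can be sketched in a few lines of Python. This is a simplified stand-in, assuming symbolic dimensions are opaque strings that only match when they are written identically; tract's real analysis works on TDim expressions. The function name `broadcast` is chosen for this example.

```python
def broadcast(a, b):
    """Apply numpy/ONNX broadcasting to two shapes given as tuples.

    Dimensions are ints or symbol strings. Two dimensions are compatible
    when they are equal or when one of them is the integer 1.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Align from the right; missing leading dimensions count as 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast {da} with {db}")
    return tuple(reversed(result))


mask = ("token_len", "token_len")

# Offset fixed to 0: the last dimension is just token_len, so this works.
print(broadcast(("batch_size", 8, "token_len", "token_len"), mask))

# Offset kept symbolic: the last dimensions differ, so this raises.
try:
    broadcast(("batch_size", 8, "token_len", "Offset+token_len"), mask)
except ValueError as err:
    print(err)  # cannot broadcast Offset+token_len with token_len
```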

Now if I set Offset to zero instead (and more cli patches), I get a different analysis error.

[...]
[2023-10-23T12:38:25.866602343Z ERROR tract] Error at stage analyse

    Caused by:
        0: ModelBuildingError
        1: #205 "/decoder/blocks.0/attn/Range" Range has incomplete typing

And it looks like this:

NUC 23/10 14:39 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=0 --pass analyse --output-node "/decoder/blocks.0/attn/Range"  decoder.onnx dump  --io-long
    Finished dev [unoptimized + debuginfo] target(s) in 0.07s
     Running `/home/kali/dev/sonos/tract/target/debug/tract --constantize offset -i 'offset:TDim=0' --pass analyse --output-node /decoder/blocks.0/attn/Range decoder.onnx dump --io-long`
┏ 2 Source kv_cache
┃   * output fact #0: 12,batch_size,451,512,F32 >190/0 >199/0 >206/0 >218/0 >268/0 >275/0 MODEL INPUT #2
┣┓
┃┣ 199 Shape /decoder/blocks.0/attn/Shape_2
┃┃   * input fact  #0: 2/0> 12,batch_size,451,512,F32
┃┃   * output fact #0: 4,TDim 12, batch_size, 451, 512 >201/0  /decoder/blocks.0/attn/Shape_2_output_0
┃┣┻ 201 Gather /decoder/blocks.0/attn/Gather_2
┃┃   * input fact  #0: 199/0> 4,TDim 12, batch_size, 451, 512
┃┃     input fact  #1: 200/0> ,I64 1
┃┃   * output fact #0: ,TDim batch_size >202/0  /decoder/blocks.0/attn/Gather_2_output_0
┃┣ 202 onnx.Cast /decoder/blocks.0/attn/Cast
┃┃   * input fact  #0: 201/0> ,TDim batch_size
┃┃   * output fact #0: TDim >205/1  /decoder/blocks.0/attn/Cast_output_0
┃┣┻┻ 205 Range /decoder/blocks.0/attn/Range
┃┃   * input fact  #0: 203/0> ,I64 0
┃┃     input fact  #1: 202/0> TDim
┃┃     input fact  #2: 204/0> ,I64 1
┃┃   * output fact #0: ?,TDim >227/0 MODEL OUTPUT #0 /decoder/blocks.0/attn/Range_output_0

Here we have a cast from TDim to TDim which erases the batch_size value, so the Range operator can't compute its output shape. This is a stupid behaviour from the Cast operator. One more patch.

And more to come.

kali commented 1 year ago

So now the Cast behaves.

NUC 23/10 14:48 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=0 --pass analyse --output-node "/decoder/blocks.0/attn/Range"  decoder.onnx dump  --io-long
    Finished dev [unoptimized + debuginfo] target(s) in 0.07s
     Running `/home/kali/dev/sonos/tract/target/debug/tract --constantize offset -i 'offset:TDim=0' --pass analyse --output-node /decoder/blocks.0/attn/Range decoder.onnx dump --io-long`
┏ 2 Source kv_cache
┃   * output fact #0: 12,batch_size,451,512,F32 >190/0 >199/0 >206/0 >218/0 >268/0 >275/0 MODEL INPUT #2
┣┓
┃┣ 199 Shape /decoder/blocks.0/attn/Shape_2
┃┃   * input fact  #0: 2/0> 12,batch_size,451,512,F32
┃┃   * output fact #0: 4,TDim 12, batch_size, 451, 512 >201/0  /decoder/blocks.0/attn/Shape_2_output_0
┃┣┻ 201 Gather /decoder/blocks.0/attn/Gather_2
┃┃   * input fact  #0: 199/0> 4,TDim 12, batch_size, 451, 512
┃┃     input fact  #1: 200/0> ,I64 1
┃┃   * output fact #0: ,TDim batch_size >202/0  /decoder/blocks.0/attn/Gather_2_output_0
┃┣ 202 onnx.Cast /decoder/blocks.0/attn/Cast
┃┃   * input fact  #0: 201/0> ,TDim batch_size
┃┃   * output fact #0: ,TDim batch_size >205/1  /decoder/blocks.0/attn/Cast_output_0
┃┣┻┻ 205 Range /decoder/blocks.0/attn/Range
┃┃   * input fact  #0: 203/0> ,I64 0
┃┃     input fact  #1: 202/0> ,TDim batch_size
┃┃     input fact  #2: 204/0> ,I64 1
┃┃   * output fact #0: ..,? >227/0 MODEL OUTPUT #0 /decoder/blocks.0/attn/Range_output_0

And now Range is showing its limit as an implementation of Expansion. I am going to open a separate issue for this one.

kali commented 1 year ago

Heavy lifting was already done, so adding the Range patch here instead.

kali commented 1 year ago

Make Range stateful (temporarily ?).

And I fixed your example to mimic the command line actions (constantize offset to 0), because there is no one-liner to "constantize".

    let mut decoder = tract_onnx::onnx().model_for_path("decoder.onnx")?;
    let offset_input_node_id = decoder.inputs[3].node;
    decoder.node_mut(offset_input_node_id).op = Box::new(Const::new(rctensor0(0.to_dim())));
    decoder.node_mut(offset_input_node_id).outputs[0].fact = Default::default();
    decoder.inputs.remove(3);
    let decoder = decoder.into_optimized()?.into_runnable()?;

With this the program executes to the end.

igor-yusupov commented 1 year ago

Can you tell me how to import "Const", please?

kali commented 1 year ago

I think this is the one:

use tract_onnx::tract_hir::ops::konst::Const;

igor-yusupov commented 1 year ago

Yes, it works now. But do I understand correctly that by setting

decoder.node_mut(offset_input_node_id).op = Box::new(Const::new(rctensor0(0.to_dim())));
decoder.node_mut(offset_input_node_id).outputs[0].fact = Default::default();
decoder.inputs.remove(3);

you just freeze the offset parameter to 0? What if I want to pass different offset values to the model?

kali commented 1 year ago

You're completely right. These three lines permanently set offset to zero. It's just a first step.

To move on, we need your expertise on the network. There is an inconsistency when tract tries to compute the shapes for the network, and without knowing what the model is supposed to do, it's pretty hard for me to find out where it diverges.

So you should run the following command, and you will get an error. Node 440 fails to make sense with its inputs and outputs: "Invalid shape (broadcasting): Sym(token_len) is not compatible with Some(Add([Sym(Offset), Sym(token_len)]))". We are trying to broadcast a tensor that has one dimension of Offset+token_len against one that has token_len. But it is not obvious where the cause of the error is: this is where I need you to have a look at the output of the command line and figure out where tract does something unexpected...


cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=Offset --pass analyse decoder.onnx dump  --io-long

igor-yusupov commented 1 year ago

I found the part of the code where the error occurs:

def qkv_attention(
    self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
    n_batch, n_ctx, n_state = q.shape
    scale = (n_state // self.n_head) ** -0.25
    q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
    k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
    v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

    qk = q @ k
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]

    w = F.softmax(qk.float(), dim=-1).to(q.dtype)
    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

Specifically:

qk = qk + mask[:n_ctx, :n_ctx]

Here qk.shape = (batch_size, 8, token_len, token_len + offset) and n_ctx = token_len.

It works in two cases: 1) token_len > 1 and offset = 0; 2) token_len = 1 and offset >= 1. In the second case, mask[:1, :1] is a single number that is simply added to the matrix. Moreover, this number is equal to 0 😀

In other cases, it really won't work. It may be possible to solve this by changing the network architecture.

kali commented 1 year ago

Well, in that case, you can verify tract is also happy with case 2:

cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=Offset -i tokens:batch_size,1,i64 decoder.onnx dump  --io-long

igor-yusupov commented 1 year ago

Great! Then I'll check the network outputs and if everything matches, we can close the issue.

igor-yusupov commented 1 year ago

So it seems to be impossible to run this model in both cases, I mean without freezing the offset parameter? If it is not possible, I guess I'll have to redesign the architecture of the model.

kali commented 1 year ago

tract will not accept the model if it cannot "prove" the dimensions are valid. We have established that tract can prove and load the model in either one of the cases: you can load the model if you give tract the hints for condition 1, or the hints for condition 2. But tract model validation is not smart enough to handle complex rules like "tokens will be 1 OR offset will be 0".

So if you anticipate a mix of inputs matching one or the other condition, you can load the model twice and, right before running, pick the version whose condition the input satisfies.

Or... rework the design.
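The "load twice and pick" option could be dispatched along these lines. This is a language-neutral sketch written in Python rather than a tract API: the function `pick_variant` and the labels "offset_zero" and "single_token" are hypothetical names for the two loaded copies of the decoder.

```python
def pick_variant(token_len: int, offset: int) -> str:
    """Return which loaded copy of the decoder can accept this input."""
    if offset == 0:
        # Condition 1: the copy loaded with offset constantized to 0.
        return "offset_zero"
    if token_len == 1:
        # Condition 2: the copy loaded with tokens declared as batch_size,1.
        return "single_token"
    raise ValueError(
        f"no loaded variant accepts token_len={token_len}, offset={offset}"
    )


print(pick_variant(token_len=4, offset=0))  # offset_zero
print(pick_variant(token_len=1, offset=4))  # single_token
```

The caller would keep both runnable models around and index them by the returned label, which is where the extra memory cost of this approach comes from.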

igor-yusupov commented 1 year ago

It seems that double loading the model wastes RAM. Yeah, I think it's definitely worth redesigning the model. Thank you very much for your responsiveness!