igor-yusupov closed this issue 1 year ago
Hey, what happened here? I was actually having a look at your model, and it's a pretty challenging one. There are multiple places where tract could handle things better, and I have already made a couple of fixes.
Did you discover something wrong with the model in the meantime that I need to be aware of?
@kali, I found a couple of ways to change the model so that it can run with tract, and decided to try to finish it myself first so as not to distract you :)
I will now upload new weights and explain why the model does not load in tract. Please wait a few minutes.
@kali new weights link: https://drive.google.com/file/d/1xDFw9Q4-CIeKgdMjo_88jx7rzZmPBFGt/view?usp=sharing
an error is currently being raised:
Caused by:
0: in output_facts invocation for /decoder/Add_1
1: Can not broadcast shapes a:batch_size,token_len,512,F32 b:1,<Sym0>,512,F32', src/main.rs:16:10
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
It is raised at the places where the slice is taken.
The Python code looks like this:
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
offset is a parameter that is passed as a model input during inference.
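A minimal sketch of what that slice does, using plain Python lists instead of tensors. The sizes (8 positions, 4 features) and the helper name `positional_slice` are made up for illustration and are not part of the model code. The point is that the slice always has `token_len` rows, whatever `offset` is:

```python
# Toy stand-in for the positional embedding table: 8 positions, 4 features.
# Row p is filled with the value p so the selected positions are easy to see.
N_POSITIONS, N_FEATURES = 8, 4
positional_embedding = [[float(p)] * N_FEATURES for p in range(N_POSITIONS)]


def positional_slice(offset, token_len):
    """Rows offset .. offset + token_len - 1 (hypothetical helper)."""
    return positional_embedding[offset : offset + token_len]


rows = positional_slice(offset=3, token_len=2)
slice_shape = (len(rows), len(rows[0]))  # (token_len, N_FEATURES)
first_position = rows[0][0]              # index of the first selected position
```

The resulting `(token_len, N_FEATURES)` block is then broadcast over the batch dimension when it is added to the token embedding.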
The model also fails with an error when it is loaded:
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Translating node #54 "/decoder/blocks.0/attn/Range" Range ToTypedTranslator
Caused by:
Range needs fixed inputs', src/main.rs:16:10
This place also uses slices.
The previous weights are also correct, though; you can use them as a reference.
@kali Is it possible to run this model with tract? I will be glad to help. If there are any questions about the model layers, I am ready to answer.
One thing that makes this model difficult is the way it kind of does a manual gather at the very beginning:
NUC 23/10 13:22 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long
Finished dev [unoptimized + debuginfo] target(s) in 0.07s
Running `/home/kali/dev/sonos/tract/target/debug/tract --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long`
┏ 1 Source tokens
┃ * output fact #0: batch_size,token_len,I64 >2/1 >5/0 MODEL INPUT #0
┣┓
┃┣┻ 2 Gather /decoder/token_embedding/Gather
┃┃ * input fact #0: 0/0> 51865,512,F32 -0.008529663, 0.020874023, -0.015792847, 0.022338867, -0.0206604, 0.018005371, -0.014289856, 0.0024528503, 0.015136719, 0.002670288, 0.021194458, -0.011421204...
┃┃ input fact #1: 1/0> batch_size,token_len,I64
┃┃ * output fact #0: batch_size,token_len,512,F32 >14/0 /decoder/token_embedding/Gather_output_0
┃┃┏ 3 Source offset
┃┃┃ * output fact #0: I64 >4/0 >8/0 MODEL INPUT #3
┃┃┣┓
┃┃┃┣ 4 AddDims /decoder/Unsqueeze
┃┃┃┃ * input fact #0: 3/0> I64
┃┃┃┃ * output fact #0: 1,I64 >13/1 /decoder/Unsqueeze_output_0
┗━━┓
┃┃┃┣ 5 Shape /decoder/Shape
┃┃┃┃ * input fact #0: 1/0> batch_size,token_len,I64
┃┃┃┃ * output fact #0: 2,TDim batch_size, token_len >7/0 /decoder/Shape_output_0
┃┃┃┣┻ 7 Gather /decoder/Gather
┃┃┃┃ * input fact #0: 5/0> 2,TDim batch_size, token_len
┃┃┃┃ input fact #1: 6/0> ,I64 1
┃┃┃┃ * output fact #0: ,TDim token_len >8/1 /decoder/Gather_output_0
┃┗┓┃
┃┃┣┻ 8 Add /decoder/Add
┃┃┃ * input fact #0: 3/0> I64
┃┃┃ input fact #1: 7/0> ,TDim token_len
┃┃┃ * output fact #0: TDim >9/0 /decoder/Add_output_0
┃┃┣ 9 AddDims /decoder/Unsqueeze_1
┃┃┃ * input fact #0: 8/0> TDim
┃┃┃ * output fact #0: 1,TDim >13/2 /decoder/Unsqueeze_1_output_0
┃
┃┣┻┻┻┻ 13 StridedSlice /decoder/Slice
┃┃ * input fact #0: 10/0> 448,512,F32 0.004878998, -0.0050811768, 0.004398346, 0.0033092499, -0.0034942627, 0.00617218, 0.010375977, -0.018112183, -0.007205963, 0.009048462, -0.00023424625, 0.0020427704...
┃┃ input fact #1: 4/0> 1,I64
┃┃ input fact #2: 9/0> 1,TDim
┃┃ input fact #3: 11/0> 1,I64 0
┃┃ input fact #4: 12/0> 1,I64 1
┃┃ * output fact #0: ?,512,F32 >14/1 /decoder/Slice_output_0
┣┻ 14 Add /decoder/Add_1
* input fact #0: 2/0> batch_size,token_len,512,F32
input fact #1: 13/0> ?,512,F32
* output fact #0: ..,F32 MODEL OUTPUT #0 /decoder/Add_1_output_0
Input #2 of /decoder/Slice (end) is an uncomputed TDim, so tract cannot compute a meaningful output shape for this StridedSlice. If we backtrack it, it comes from node 3 through node 8 and node 9. Node 3 is a scalar input (offset), so we can try to give it a TDim value that tract will be able to operate on. With this (and a couple of patches to the command line) we get:
NUC 23/10 13:53 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- -v --constantize offset -i offset:TDim=Offset --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long
Finished dev [unoptimized + debuginfo] target(s) in 0.09s
Running `/home/kali/dev/sonos/tract/target/debug/tract -v --constantize offset -i 'offset:TDim=Offset' --pass analyse --output-node /decoder/Add_1 decoder.onnx --partial dump --io-long`
[2023-10-23T11:54:01.790809955Z INFO tract] Resource usage init: vsz:66998272 rsz:11763712 rszmax:49049600
[2023-10-23T11:54:01.791277650Z INFO tract] Resource usage loaded framework (onnx): vsz:66998272 rsz:20283392 rszmax:49049600
[2023-10-23T11:54:01.903374841Z INFO tract] Resource usage proto model loaded: vsz:385937408 rsz:340602880 rszmax:655974400
[2023-10-23T11:54:02.323843040Z INFO tract::params] Model Fs("decoder.onnx") loaded
[2023-10-23T11:54:02.323911908Z INFO tract] Resource usage model loaded: vsz:391233536 rsz:349589504 rszmax:666918912
[2023-10-23T11:54:02.326423648Z INFO tract::params] Commuting #3 "offset" Source, fact:I64 into a const of ,TDim Offset
[2023-10-23T11:54:02.332996880Z INFO tract::params] Will stop at analyse
[2023-10-23T11:54:02.333007671Z INFO tract::params] Running 'analyse'
[2023-10-23T11:54:02.733644983Z INFO tract] Resource usage after analyse: vsz:178794496 rsz:139341824 rszmax:666918912
[2023-10-23T11:54:02.733670757Z INFO tract::params] Model ready
[2023-10-23T11:54:02.733676664Z INFO tract] Resource usage model ready: vsz:178794496 rsz:139341824 rszmax:666918912
┏ 1 Source tokens
┃ * output fact #0: batch_size,token_len,I64 >2/1 >5/0 MODEL INPUT #0
┣┓
┃┣┻ 2 Gather /decoder/token_embedding/Gather
┃┃ * input fact #0: 0/0> 51865,512,F32 -0.008529663, 0.020874023, -0.015792847, 0.022338867, -0.0206604, 0.018005371, -0.014289856, 0.0024528503, 0.015136719, 0.002670288, 0.021194458, -0.011421204...
┃┃ input fact #1: 1/0> batch_size,token_len,I64
┃┃ * output fact #0: batch_size,token_len,512,F32 >14/0 /decoder/token_embedding/Gather_output_0
┃┃┣ 4 AddDims /decoder/Unsqueeze
┃┃┃ * input fact #0: 3/0> ,TDim Offset
┃┃┃ * output fact #0: 1,TDim Offset >13/1 /decoder/Unsqueeze_output_0
┗━┓
┃┃┣ 5 Shape /decoder/Shape
┃┃┃ * input fact #0: 1/0> batch_size,token_len,I64
┃┃┃ * output fact #0: 2,TDim batch_size, token_len >7/0 /decoder/Shape_output_0
┃┃┣┻ 7 Gather /decoder/Gather
┃┃┃ * input fact #0: 5/0> 2,TDim batch_size, token_len
┃┃┃ input fact #1: 6/0> ,I64 1
┃┃┃ * output fact #0: ,TDim token_len >8/1 /decoder/Gather_output_0
┃┃┣┻ 8 Add /decoder/Add
┃┃┃ * input fact #0: 3/0> ,TDim Offset
┃┃┃ input fact #1: 7/0> ,TDim token_len
┃┃┃ * output fact #0: ,TDim Offset+token_len >9/0 /decoder/Add_output_0
┃┃┣ 9 AddDims /decoder/Unsqueeze_1
┃┃┃ * input fact #0: 8/0> ,TDim Offset+token_len
┃┃┃ * output fact #0: 1,TDim Offset+token_len >13/2 /decoder/Unsqueeze_1_output_0
┃
┃┣┻┻┻┻ 13 StridedSlice /decoder/Slice
┃┃ * input fact #0: 10/0> 448,512,F32 0.004878998, -0.0050811768, 0.004398346, 0.0033092499, -0.0034942627, 0.00617218, 0.010375977, -0.018112183, -0.007205963, 0.009048462, -0.00023424625, 0.0020427704...
┃┃ input fact #1: 4/0> 1,TDim Offset
┃┃ input fact #2: 9/0> 1,TDim Offset+token_len
┃┃ input fact #3: 11/0> 1,I64 0
┃┃ input fact #4: 12/0> 1,I64 1
┃┃ * output fact #0: token_len,512,F32 >14/1 /decoder/Slice_output_0
┣┻ 14 Add /decoder/Add_1
* input fact #0: 2/0> batch_size,token_len,512,F32
input fact #1: 13/0> token_len,512,F32
* output fact #0: batch_size,token_len,512,F32 MODEL OUTPUT #0 /decoder/Add_1_output_0
[2023-10-23T11:54:02.734734055Z INFO tract] Resource usage done: vsz:71651328 rsz:34631680 rszmax:666918912
But we're not there yet. More to come.
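Why the symbol helps can be sketched with a toy model of symbolic dimensions. This is not tract's TDim implementation, only an illustration under the assumption that dimensions are linear expressions stored as `{symbol: coefficient}` dictionaries. The helper names `sym`, `add` and `sub` are hypothetical.

```python
def sym(name):
    """A linear expression consisting of one symbol with coefficient 1."""
    return {name: 1}


def add(a, b):
    """Sum of two linear expressions, dropping terms that cancel out."""
    out = dict(a)
    for name, coeff in b.items():
        out[name] = out.get(name, 0) + coeff
    return {name: coeff for name, coeff in out.items() if coeff != 0}


def sub(a, b):
    """Difference of two linear expressions."""
    return add(a, {name: -coeff for name, coeff in b.items()})


offset = sym("Offset")
token_len = sym("token_len")

begin = offset                  # input #1 of the slice
end = add(offset, token_len)    # input #2 of the slice
slice_len = sub(end, begin)     # stride is 1, so length = end - begin
```

With a plain I64 input the value of `end` is unknown at analysis time, so the length stays `?`. With a symbol, `Offset` cancels out and the length reduces to `token_len`, which matches the output fact of node 13 in the second dump.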
So after more fixes to the command line, I get there:
┣┻ 276 Add /decoder/blocks.0/attn/Add_10
* input fact #0: 264/0> batch_size,8,token_len,Offset+token_len,F32
input fact #1: 275/0> token_len,token_len,F32
* output fact #0: ..,? MODEL OUTPUT #0 /decoder/blocks.0/attn/Add_10_output_0
[2023-10-23T12:16:59.275900759Z ERROR tract] Error at stage analyse
Caused by:
0: Failed analyse for node #276 "/decoder/blocks.0/attn/Add_10" Add
1: Failed analyse for node #276 "/decoder/blocks.0/attn/Add_10" Add
2: Infering facts
3: Applying rule WithRule { inputs[1].shape }
4: Matching batch_size,8,token_len,Offset+token_len and token_len,token_len with numpy/onnx broadcast rules
5: Invalid shape (broadcasting): Sym(token_len) is not compatible with Some(Add([Sym(Offset), Sym(token_len)])).
We are trying to add a tensor of shape batch_size,8,token_len,Offset+token_len and one of shape token_len,token_len. This cannot work unless Offset, the symbol I introduced previously, is zero, which is what your code sample does (setting it to 0 instead of a symbol). I think something is already wrong there, because I assume this is meant to work with any value of Offset.
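The broadcast check that fails here can be sketched as follows. This is a simplified illustration of numpy/ONNX broadcasting over shapes whose dimensions are either integers or opaque symbolic strings. It is not tract's code, and the `broadcast` helper is hypothetical.

```python
def broadcast(a, b):
    """Broadcast two shapes, numpy/ONNX style.

    Dimensions are ints or symbolic strings. Two dimensions are accepted only
    if they are identical or one of them is the literal 1. Returns the result
    shape, or None when compatibility cannot be established.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Align shapes on the right; missing leading dimensions count as 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            return None
    return tuple(reversed(result))


mask = ("token_len", "token_len")
qk_symbolic = ("batch_size", 8, "token_len", "Offset+token_len")
qk_offset_zero = ("batch_size", 8, "token_len", "token_len")

symbolic_result = broadcast(qk_symbolic, mask)        # None: cannot be proven
offset_zero_result = broadcast(qk_offset_zero, mask)  # same shape as qk
```

With Offset left symbolic, the last dimensions are `Offset+token_len` and `token_len`, which are neither equal nor 1, so the check is rejected. Once Offset is fixed to 0 both dimensions are `token_len` and the add is accepted.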
Now if I set Offset to zero instead (and more cli patches), I get a different analysis error.
[...]
[2023-10-23T12:38:25.866602343Z ERROR tract] Error at stage analyse
Caused by:
0: ModelBuildingError
1: #205 "/decoder/blocks.0/attn/Range" Range has incomplete typing
And it looks like this:
NUC 23/10 14:39 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=0 --pass analyse --output-node "/decoder/blocks.0/attn/Range" decoder.onnx dump --io-long
Finished dev [unoptimized + debuginfo] target(s) in 0.07s
Running `/home/kali/dev/sonos/tract/target/debug/tract --constantize offset -i 'offset:TDim=0' --pass analyse --output-node /decoder/blocks.0/attn/Range decoder.onnx dump --io-long`
┏ 2 Source kv_cache
┃ * output fact #0: 12,batch_size,451,512,F32 >190/0 >199/0 >206/0 >218/0 >268/0 >275/0 MODEL INPUT #2
┣┓
┃┣ 199 Shape /decoder/blocks.0/attn/Shape_2
┃┃ * input fact #0: 2/0> 12,batch_size,451,512,F32
┃┃ * output fact #0: 4,TDim 12, batch_size, 451, 512 >201/0 /decoder/blocks.0/attn/Shape_2_output_0
┃┣┻ 201 Gather /decoder/blocks.0/attn/Gather_2
┃┃ * input fact #0: 199/0> 4,TDim 12, batch_size, 451, 512
┃┃ input fact #1: 200/0> ,I64 1
┃┃ * output fact #0: ,TDim batch_size >202/0 /decoder/blocks.0/attn/Gather_2_output_0
┃┣ 202 onnx.Cast /decoder/blocks.0/attn/Cast
┃┃ * input fact #0: 201/0> ,TDim batch_size
┃┃ * output fact #0: TDim >205/1 /decoder/blocks.0/attn/Cast_output_0
┃┣┻┻ 205 Range /decoder/blocks.0/attn/Range
┃┃ * input fact #0: 203/0> ,I64 0
┃┃ input fact #1: 202/0> TDim
┃┃ input fact #2: 204/0> ,I64 1
┃┃ * output fact #0: ?,TDim >227/0 MODEL OUTPUT #0 /decoder/blocks.0/attn/Range_output_0
Here we have a cast from TDim to TDim which erases the batch_size, so the Range operator can't compute its output shape. This is a stupid behaviour from the Cast operator. One more patch.
And more to come.
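For context, the output length of Range depends on the values of its inputs, not only on their types, which is why losing batch_size in the Cast matters. A small sketch of the element count from the ONNX Range definition, for concrete scalar inputs (the function name `range_len` is made up for illustration):

```python
import math


def range_len(start, limit, delta):
    """Number of elements Range produces: max(ceil((limit - start) / delta), 0)."""
    return max(math.ceil((limit - start) / delta), 0)


n_forward = range_len(0, 4, 1)    # 0, 1, 2, 3
n_strided = range_len(0, 5, 2)    # 0, 2, 4
n_empty = range_len(3, 0, 1)      # limit below start with positive delta
```

In the dump above, start is 0 and delta is 1, so the length is simply the value of input #1. When that value is a bare TDim with no expression attached, the output shape can only be `?`.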
So now the Cast behaves.
NUC 23/10 14:48 ~/dev/sonos/tract/issue-1202% cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=0 --pass analyse --output-node "/decoder/blocks.0/attn/Range" decoder.onnx dump --io-long
Finished dev [unoptimized + debuginfo] target(s) in 0.07s
Running `/home/kali/dev/sonos/tract/target/debug/tract --constantize offset -i 'offset:TDim=0' --pass analyse --output-node /decoder/blocks.0/attn/Range decoder.onnx dump --io-long`
┏ 2 Source kv_cache
┃ * output fact #0: 12,batch_size,451,512,F32 >190/0 >199/0 >206/0 >218/0 >268/0 >275/0 MODEL INPUT #2
┣┓
┃┣ 199 Shape /decoder/blocks.0/attn/Shape_2
┃┃ * input fact #0: 2/0> 12,batch_size,451,512,F32
┃┃ * output fact #0: 4,TDim 12, batch_size, 451, 512 >201/0 /decoder/blocks.0/attn/Shape_2_output_0
┃┣┻ 201 Gather /decoder/blocks.0/attn/Gather_2
┃┃ * input fact #0: 199/0> 4,TDim 12, batch_size, 451, 512
┃┃ input fact #1: 200/0> ,I64 1
┃┃ * output fact #0: ,TDim batch_size >202/0 /decoder/blocks.0/attn/Gather_2_output_0
┃┣ 202 onnx.Cast /decoder/blocks.0/attn/Cast
┃┃ * input fact #0: 201/0> ,TDim batch_size
┃┃ * output fact #0: ,TDim batch_size >205/1 /decoder/blocks.0/attn/Cast_output_0
┃┣┻┻ 205 Range /decoder/blocks.0/attn/Range
┃┃ * input fact #0: 203/0> ,I64 0
┃┃ input fact #1: 202/0> ,TDim batch_size
┃┃ input fact #2: 204/0> ,I64 1
┃┃ * output fact #0: ..,? >227/0 MODEL OUTPUT #0 /decoder/blocks.0/attn/Range_output_0
And now Range is showing its limits as an implementation of Expansion. I am going to open a separate issue for this one.
The heavy lifting was already done, so I am adding the Range patch here instead.
Make Range stateful (temporarily?).
And I fixed your example to mimic the command line actions (constantize offset to 0), because there is no one-liner to "constantize".
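// Load the ONNX decoder as an inference model.
let mut decoder = tract_onnx::onnx().model_for_path("decoder.onnx")?;
// Turn the `offset` source node (model input #3) into a constant TDim of 0.
let offset_input_node_id = decoder.inputs[3].node;
decoder.node_mut(offset_input_node_id).op = Box::new(Const::new(rctensor0(0.to_dim())));
// Clear the declared output fact so it no longer describes an I64 input.
decoder.node_mut(offset_input_node_id).outputs[0].fact = Default::default();
// `offset` is no longer a model input.
decoder.inputs.remove(3);
let decoder = decoder.into_optimized()?.into_runnable()?;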
With this the program executes to the end.
Can you tell how to import "Const" please?
I think this is the one:
use tract_onnx::tract_hir::ops::konst::Const;
Yes, it works now. But do I understand correctly that by setting
decoder.node_mut(offset_input_node_id).op = Box::new(Const::new(rctensor0(0.to_dim())));
decoder.node_mut(offset_input_node_id).outputs[0].fact = Default::default();
decoder.inputs.remove(3);
you just freeze the offset parameter to 0? What if I want to give different offset values to the model input?
You're completely right. These three lines permanently set offset to zero. It's just a first step.
To move on, we need your expertise on the network. There is an inconsistency when tract tries to compute the shapes for the network, and without knowing what the model is supposed to do, it's pretty hard for me to find out where it diverges.
So you should run the following command, and you will get an error. Node 440 fails to make sense with its inputs and outputs: "Invalid shape (broadcasting): Sym(token_len) is not compatible with Some(Add([Sym(Offset), Sym(token_len)]))". We are trying to broadcast a tensor that has one dimension of Offset+token_len against one that has token_len. But it is not obvious where the cause of the error is: this is where I need you to have a look at the output of the command line and figure out where tract does something unexpected...
cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=Offset --pass analyse decoder.onnx dump --io-long
I found the part of the code where the error occurs
def qkv_attention(
    self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
    n_batch, n_ctx, n_state = q.shape
    scale = (n_state // self.n_head) ** -0.25
    q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
    k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
    v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
    qk = q @ k
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]
    w = F.softmax(qk.float(), dim=-1).to(q.dtype)
    return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
Specifically:
qk = qk + mask[:n_ctx, :n_ctx]
qk.shape = (batch_size, 8, token_len, token_len + offset)
n_ctx = token_len
It works in 2 cases: 1) token_len > 1 and offset = 0. 2) token_len = 1 and offset >= 1. In the second case, a single number is simply added to the matrix. Moreover, this number is equal to 0 😀
In other cases, it really won't work. It may be possible to solve this by changing the network architecture.
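The two working cases can be checked with concrete numbers. The sketch below only reasons about the last dimension of `qk` and of the sliced mask, since the row dimension is `token_len` on both sides. It assumes the usual causal mask layout (0 on and below the diagonal, -inf above), which the model code above does not show, and both helper names are made up.

```python
import math


def causal_mask(n):
    """n x n mask: 0 on and below the diagonal, -inf above (assumed layout)."""
    return [[-math.inf if j > i else 0.0 for j in range(n)] for i in range(n)]


def mask_add_is_valid(token_len, offset):
    """Can mask[:n_ctx, :n_ctx] be broadcast onto the last dim of qk?"""
    n_ctx = token_len
    qk_last = token_len + offset   # last dimension of qk
    mask_last = n_ctx              # last dimension of the sliced mask
    # Broadcasting needs equal dimensions, or a mask dimension of 1.
    return mask_last == qk_last or mask_last == 1


case_1 = mask_add_is_valid(token_len=5, offset=0)   # equal dims
case_2 = mask_add_is_valid(token_len=1, offset=7)   # mask slice is 1 x 1
other = mask_add_is_valid(token_len=3, offset=2)    # 3 vs 5: no broadcast
single_value = causal_mask(4)[0][0]                 # the value added in case 2
```

In case 2 the sliced mask is the single element `mask[0][0]`, which is 0 under the assumed layout, so the add leaves `qk` unchanged.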
Well, in that case, you can verify tract is also happy with case 2:
cargo run -p tract --manifest-path=../cli/Cargo.toml -- --constantize offset -i offset:TDim=Offset -i tokens:batch_size,1,i64 decoder.onnx dump --io-long
Great! Then I'll check the network outputs and if everything matches, we can close the issue.
So it seems to be impossible to run this model in both cases with a single loaded model, that is, without freezing the offset parameter? If it is not possible, I guess I'll have to redesign the architecture of the model.
tract will not accept the model if it cannot "prove" the dimensions are valid. We have established that tract can prove and load the model in either case: you can load the model if you give tract the hints for condition 1, or the hints for condition 2. But tract's model validation is not smart enough to handle complex rules like "tokens will be 1 OR offset will be 0".
So if you anticipate a mix of inputs matching one or the other condition, you can load the model twice and, right before running it, pick the version that matches the condition the input satisfies.
Or... rework the design.
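The dispatch for the two-copies approach could look roughly like this. It is only a sketch of the selection logic: the variant names `"prefill"` and `"step"` and the function `pick_variant` are hypothetical, and actually loading and running the two tract models is left out.

```python
def pick_variant(token_len, offset):
    """Choose which of the two loaded model copies can handle this input.

    "prefill" stands for the copy loaded with offset fixed to 0,
    "step" for the copy loaded with token_len fixed to 1.
    Both names are placeholders for illustration.
    """
    if offset == 0:
        return "prefill"
    if token_len == 1:
        return "step"
    raise ValueError(
        f"no loaded variant supports token_len={token_len}, offset={offset}"
    )


first_call = pick_variant(token_len=5, offset=0)    # whole prompt at once
later_call = pick_variant(token_len=1, offset=5)    # one new token per call
```

Inputs that satisfy neither condition are rejected up front, which mirrors the fact that the mask add is not valid for them in the first place.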
It seems that double loading the model wastes RAM. Yeah, I think it's definitely worth redesigning the model. Thank you very much for your responsiveness!
I tried to run the model, but can't get it to load.
Using the version from the main branch, I get this error:
With version 0.20.18, I get this error:
weights link: https://www.dropbox.com/scl/fi/sl4rajtlsy0sfmpavkbtl/decoder_base_fix_kv_cache.onnx?rlkey=q2i2ju81t6vb7twwimf47ct1u&dl=0
Without calling into_optimized(), loading the weights works, but inference runs indefinitely :)
I would be grateful if you could tell me what I need to fix.