pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Session.run error #264

Closed: AbhishekBose closed this issue 2 months ago

AbhishekBose commented 2 months ago

For this code:

#[post("/predict")]
async fn predict(data: web::Data<AppState>, req: web::Json<PredictRequest>) -> impl Responder {
    let inputs = &req.texts;

    // Encode input strings.
    let encodings = data.tokenizer.encode_batch(inputs.clone(), false).unwrap();
    let padded_token_length = encodings[0].len();

    let ids: Vec<i64> = encodings.iter().flat_map(|e| e.get_ids().iter().map(|i| *i as i64)).collect();
    let mask: Vec<i64> = encodings.iter().flat_map(|e| e.get_attention_mask().iter().map(|i| *i as i64)).collect();

    let a_ids = Array2::from_shape_vec([inputs.len(), padded_token_length], ids).unwrap();
    let a_mask = Array2::from_shape_vec([inputs.len(), padded_token_length], mask).unwrap();

    // Run the model.
    let outputs = data.session.run(ort::inputs![a_ids, a_mask]).unwrap();
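
The ids/mask flattening above only works if every encoding has the same length, because Array2::from_shape_vec expects exactly rows × cols elements laid out row by row and returns an error otherwise. Here is a std-only sketch of that check. It assumes the token IDs arrive as Vec<Vec<u32>>, and flatten_padded is a hypothetical helper, not part of tokenizers or ndarray:

```rust
/// Hypothetical helper: flattens per-sequence token IDs into one row-major
/// Vec<i64> (like the `ids` and `mask` vectors in the handler) and returns the
/// (batch, seq_len) shape. Returns an error if the rows have different lengths,
/// which is the case a shape check like `from_shape_vec`'s would reject.
fn flatten_padded(rows: &[Vec<u32>]) -> Result<((usize, usize), Vec<i64>), String> {
    let batch = rows.len();
    // Like `encodings[0].len()`: take the first row's length as the padded length.
    let seq_len = rows.first().map_or(0, |r| r.len());

    if let Some((i, r)) = rows.iter().enumerate().find(|(_, r)| r.len() != seq_len) {
        return Err(format!("sequence {i} has length {}, expected {seq_len}", r.len()));
    }

    // Row-major: all of row 0, then all of row 1, and so on.
    let flat: Vec<i64> = rows.iter().flat_map(|r| r.iter().map(|&t| t as i64)).collect();
    debug_assert_eq!(flat.len(), batch * seq_len);
    Ok(((batch, seq_len), flat))
}

fn main() {
    // Two sequences padded to length 3 (0 is the padding ID in this toy example).
    let ids = vec![vec![101, 2023, 102], vec![101, 102, 0]];
    let ((batch, seq_len), flat) = flatten_padded(&ids).unwrap();
    println!("shape = [{batch}, {seq_len}], data = {flat:?}");

    // An unpadded batch has ragged rows, so the shape check fails.
    let ragged = vec![vec![101, 2023, 102], vec![101, 102]];
    assert!(flatten_padded(&ragged).is_err());
}
```

In the original handler this presumably holds because the tokenizer is configured to pad the batch. If it isn't, from_shape_vec returns an error and the .unwrap() panics before the model is ever run.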

I am getting the following error:

error[E0277]: the trait bound `SessionInputs<'_, '_, _>: std::convert::From<Result<[SessionInputValue<'_>; 2], ort::Error>>` is not satisfied
   --> src/main.rs:48:36
    |
48  |     let outputs = data.session.run(ort::inputs![a_ids, a_mask]).unwrap();
    |                                --- ^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `std::convert::From<Result<[SessionInputValue<'_>; 2], ort::Error>>` is not implemented for `SessionInputs<'_, '_, _>`, which is required by `Result<[SessionInputValue<'_>; 2], ort::Error>: Into<SessionInputs<'_, '_, _>>`
    |                                |
    |                                required by a bound introduced by this call
    |
    = help: the following other types implement trait `std::convert::From<T>`:
              <SessionInputs<'i, 'v, N> as std::convert::From<[SessionInputValue<'v>; N]>>
              <SessionInputs<'i, 'v> as std::convert::From<&'i [SessionInputValue<'v>]>>
              <SessionInputs<'i, 'v> as std::convert::From<HashMap<K, V>>>
              <SessionInputs<'i, 'v> as std::convert::From<Vec<(K, V)>>>
    = note: required for `Result<[SessionInputValue<'_>; 2], ort::Error>` to implement `Into<SessionInputs<'_, '_, _>>`
note: required by a bound in `ort::Session::run`

I was following this example.

p1atdev commented 2 months ago

I had the same problem. It's caused by the ndarray version: using 0.15 instead of 0.16 works for me with these dependencies:

[dependencies]
ort = { version = "2.0.0-rc.4" }
ndarray = "0.15"
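
Here is a sketch of why an ndarray version mismatch can break these conversions. It assumes ort 2.0.0-rc.4 is built against ndarray 0.15, which is consistent with the downgrade fixing it. For 0.x crates, 0.15 and 0.16 are semver-incompatible, so Cargo compiles both, and Array2 from one is a different type from Array2 from the other. Trait impls that ort provides for its own ndarray version therefore don't apply to arrays from the other. The modules and trait below are hypothetical stand-ins, not the real crates:

```rust
// Hypothetical stand-ins for two semver-incompatible versions of one crate.
// Cargo treats them as separate crates, so their types are unrelated.
mod ndarray_0_15 {
    pub struct Array2(pub Vec<i64>);
}
mod ndarray_0_16 {
    #[allow(dead_code)]
    pub struct Array2(pub Vec<i64>);
}

/// Hypothetical stand-in for a conversion a library implements only for the
/// ndarray version it was compiled against.
trait IntoModelInput {
    fn into_model_input(self) -> Vec<i64>;
}

impl IntoModelInput for ndarray_0_15::Array2 {
    fn into_model_input(self) -> Vec<i64> {
        self.0
    }
}

/// Hypothetical stand-in for a library entry point that requires the trait.
fn run<T: IntoModelInput>(input: T) -> usize {
    input.into_model_input().len()
}

fn main() {
    // Compiles: the array type comes from the version the library uses.
    let n = run(ndarray_0_15::Array2(vec![1, 2, 3]));
    println!("ran with {n} elements");

    // Would not compile (E0277): structurally identical, but a different type.
    // run(ndarray_0_16::Array2(vec![1, 2, 3]));
    let _other = ndarray_0_16::Array2(vec![]);
}
```

Running cargo tree -d (short for --duplicates) lists crates that appear in more than one version, so a mismatch like this shows up there as two ndarray entries.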

decahedron1 commented 2 months ago

ort::inputs! returns a Result, so it has to be unwrapped before it is passed to run: session.run(ort::inputs![a_ids, a_mask].unwrap()).
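
Here is a std-only sketch of the mismatch the compiler reports. The conversion into the session's input type is implemented for an array of values, not for a Result wrapping that array, so the Result has to be unwrapped (or propagated with ?) first. SessionInputs, SessionInputValue, Error, inputs and run below are simplified hypothetical stand-ins modeled on the names in the error message, not ort's actual definitions. The length check in inputs is made up only so the error path can be exercised:

```rust
use std::fmt;

// Simplified, hypothetical stand-ins for the types named in the compiler error.
struct SessionInputValue(Vec<i64>);

struct SessionInputs<const N: usize>([SessionInputValue; N]);

// Mirrors the first `help:` line: a conversion from an array of values only.
impl<const N: usize> From<[SessionInputValue; N]> for SessionInputs<N> {
    fn from(values: [SessionInputValue; N]) -> Self {
        SessionInputs(values)
    }
}

#[derive(Debug)]
struct Error(String);

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Stand-in for `ort::inputs!`: building inputs can fail, so it returns a Result.
fn inputs(a: Vec<i64>, b: Vec<i64>) -> Result<[SessionInputValue; 2], Error> {
    if a.len() != b.len() {
        return Err(Error("input lengths differ".into()));
    }
    Ok([SessionInputValue(a), SessionInputValue(b)])
}

// Stand-in for `Session::run`: accepts anything convertible into SessionInputs.
// Here it just returns the total number of elements across all inputs.
fn run<const N: usize>(input: impl Into<SessionInputs<N>>) -> usize {
    let SessionInputs(values): SessionInputs<N> = input.into();
    values.iter().map(|v| v.0.len()).sum()
}

fn predict(ids: Vec<i64>, mask: Vec<i64>) -> Result<usize, Error> {
    // `run(inputs(ids, mask))` would fail with E0277, because
    // Result<[_; 2], Error> is not Into<SessionInputs<_>>.
    // `?` extracts the array (or returns the error early) instead.
    Ok(run(inputs(ids, mask)?))
}

fn main() {
    // Equivalent of the suggested fix: unwrap the Result before calling run.
    let total = run(inputs(vec![1, 2], vec![1, 1]).unwrap());
    println!("total elements: {total}");

    match predict(vec![1, 2, 3], vec![1, 1]) {
        Ok(n) => println!("ok: {n}"),
        Err(e) => println!("error: {e}"),
    }
}
```

In the question's handler, which returns impl Responder, using ? would also mean changing the return type to a Result and mapping ort's error into something actix-web can turn into a response. The .unwrap() form keeps the original signature but panics if building the inputs fails.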