pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Problem with ort and WASM #260

Closed: all-c-a-p-s closed this issue 2 months ago.

all-c-a-p-s commented 2 months ago

I switched from the tract-onnx crate to ort to load a .onnx model in Rust. I followed the instructions at https://ort.pyke.io/setup/webassembly to integrate ort with WASM in my project. However, I still get an error in the browser console when I build and serve the project with trunk serve.

Here is the WASM-related error:

TypeError: Failed to execute 'decode' on 'TextDecoder': The encoded data was not valid.
    at getStringFromWasm0 (graffiti-12c00882611d1559.js:18:30)
    at imports.wbg.__wbg_error_d2d279fddc1936c2 (graffiti-12c00882611d1559.js:301:27)
    at graffiti-95270a5d58d49849.wasm.eframe::web::panic_handler::error::he4aea513b62c29d8 (graffiti-12c00882611d1559_bg.wasm:0xb9a778)
    at graffiti-95270a5d58d49849.wasm.eframe::web::panic_handler::PanicHandler::install::{{closure}}::hd32e90139075ff91 (graffiti-12c00882611d1559_bg.wasm:0x676f90)
    at graffiti-95270a5d58d49849.wasm.std::panicking::rust_panic_with_hook::h6731baa78621a747 (graffiti-12c00882611d1559_bg.wasm:0xac87d0)
    at graffiti-95270a5d58d49849.wasm.std::panicking::begin_panic_handler::{{closure}}::hb6cd8464ed39ae71 (graffiti-12c00882611d1559_bg.wasm:0xb5d383)
    at graffiti-95270a5d58d49849.wasm.std::sys_common::backtrace::__rust_end_short_backtrace::hbdf3ddeb21a1e747 (graffiti-12c00882611d1559_bg.wasm:0xcb007f)
    at graffiti-95270a5d58d49849.wasm.rust_begin_unwind (graffiti-12c00882611d1559_bg.wasm:0xc5d91f)
    at graffiti-95270a5d58d49849.wasm.core::panicking::panic_fmt::h5c7ce52813e94bcd (graffiti-12c00882611d1559_bg.wasm:0xc7026e)
    at graffiti-95270a5d58d49849.wasm.core::cell::panic_already_borrowed::h18b8189a0fdd8b58 (graffiti-12c00882611d1559_bg.wasm:0xc33747)
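Reading the trace from the bottom up, the original panic comes from `core::cell::panic_already_borrowed`. That is where `RefCell::borrow_mut` panics when the cell is already borrowed. The `TextDecoder` error at the top is raised afterwards, while eframe's panic handler is trying to log that panic through `console.error`. The sketch below reproduces the same class of panic in plain std Rust, with no ort or eframe involved. It also shows `try_borrow_mut` as a non-panicking check. The function names are hypothetical, and the sketch does not identify which `RefCell` in the real application was borrowed twice.

```rust
use std::cell::RefCell;
use std::panic;

/// Returns true if `cell` could be mutably borrowed right now.
/// `try_borrow_mut` reports a conflict as `Err` instead of panicking.
fn can_borrow_mut<T>(cell: &RefCell<T>) -> bool {
    cell.try_borrow_mut().is_ok()
}

/// Checks mutability while a shared borrow is alive, then again after it is dropped.
/// Returns (while_reading, after_reading).
fn demo_try_borrow() -> (bool, bool) {
    let cell = RefCell::new(vec![0.0f32; 4]);
    let while_reading = {
        // `_reader` is a real binding, so the shared borrow lives to the end of this block.
        let _reader = cell.borrow();
        can_borrow_mut(&cell)
    };
    let after_reading = can_borrow_mut(&cell);
    (while_reading, after_reading)
}

/// Calls `borrow_mut` while a shared borrow is still alive. This panics with
/// "already borrowed", which is the `core::cell::panic_already_borrowed` frame above.
/// Returns true if that panic occurred.
fn double_borrow_panics() -> bool {
    let result = panic::catch_unwind(|| {
        let cell = RefCell::new(0u32);
        let _reader = cell.borrow();
        let _writer = cell.borrow_mut(); // panics here
    });
    result.is_err()
}
```

For example, `demo_try_borrow()` returns `(false, true)`. The borrow conflict is reported as a value while the reader is alive, and it clears once the reader is dropped.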

Here is all of my ort-related code:

use std::collections::HashMap;

use ort::{Session, Tensor};

...

// Include the NN model data at compile time
static GRADE_MODEL_BYTES: &[u8] = include_bytes!("/Users/seba/rs/graffiti/models/custom_model.ort");
static ROUTESET_MODEL_BYTES: &[u8] = include_bytes!("/Users/seba/rs/graffiti/models/routeset/routeset.ort");

pub fn run_model(
    start_holds: Vec<String>,
    finish_holds: Vec<String>,
    intermediate_holds: Vec<String>,
) -> ort::Result<String> {
    let mut holds_data: Vec<f32> = vec![0.0; 198];

    ...

    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();

    let session = Session::builder()?.commit_from_memory(GRADE_MODEL_BYTES)?;

    let input_holds = Tensor::from_array(([1, 198], holds_data.clone().into_boxed_slice()))?;
    let mut inputs = HashMap::new();
    inputs.insert("input_layer", input_holds);

    let outputs = session.run(inputs)?;

    let mut probabilities: Vec<f32> = Vec::new();

    for (_, output_value) in outputs.iter() {
        probabilities = output_value
            .to_owned()
            .try_extract_tensor::<f32>()?
            .iter()
            .cloned()
            .collect::<Vec<f32>>();
    }

    let mut max: f32 = 0.0;
    let mut most_likely_grade = 4;

    ...
}

pub fn run_routeset_model(
    start_holds: &Vec<String>,
    finish_holds: &Vec<String>,
    intermediate_holds: &Vec<String>,
    grade: f32,
) -> ort::Result<Option<String>> {
    let mut holds_data: Vec<Vec<f32>> = vec![vec![0.0f32; 11]; 18];
    ...

    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();

    let session = Session::builder()?.commit_from_memory(ROUTESET_MODEL_BYTES)?;

    let input_vector = holds_data.iter().flatten().cloned().collect::<Vec<f32>>();
    let input_holds = Tensor::from_array(([1, 18, 11], input_vector.clone().into_boxed_slice()))?;
    let input_grade = Tensor::from_array(([1, 1], vec![grade].into_boxed_slice()))?;

    let mut inputs = HashMap::new();
    inputs.insert("input_holds", input_holds);
    inputs.insert("input_grades", input_grade);

    let outputs = session.run(inputs)?;

    let mut probabilities: Vec<f32> = Vec::new();
    for (_, output_value) in outputs.iter() {
        probabilities = output_value
            .to_owned()
            .try_extract_tensor::<f32>()?
            .iter()
            .cloned()
            .collect::<Vec<f32>>();
    }

    ...
}

pub fn generate_route(
    start_holds: Vec<String>,
    finish_holds: Vec<String>,
    intermediate_holds: Vec<String>,
    grade: usize,
) -> (Vec<String>, Vec<String>, Vec<String>) {
    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();
    ...
    let mut next_hold = run_routeset_model(
        &start_holds,
        &finish_holds,
        &intermediate_holds,
        grade as f32,
    )
    .expect("failed to run model");
    while next_hold.is_some() {
        ...
        next_hold = run_routeset_model(&s, &f, &i, grade as f32).expect("failed to run model");
    }
    ...
}
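In the code above, `ort::wasm::initialize()` is called in three places. It runs in `run_model`, in `run_routeset_model`, and again in `generate_route` before the loop that calls `run_routeset_model`. Each `run_*` call also builds a new `Session` from the embedded bytes, so the routeset model is rebuilt on every loop iteration. The thread does not establish that this is related to the panic. If one-time setup is wanted regardless of backend, a common std pattern is `std::sync::OnceLock`. The sketch below uses a hypothetical `load_model` as a stand-in for the real initialization and session construction. It also uses a hypothetical counter only to show how many times that setup ran. Whether a real ort `Session` can be stored in a `static` this way depends on its `Send`/`Sync` bounds, which are not checked here.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Hypothetical counter, used only to observe how often the loader runs.
static LOAD_CALLS: AtomicUsize = AtomicUsize::new(0);

// Hypothetical stand-in for expensive one-time setup, such as
// runtime initialization plus building a model from embedded bytes.
fn load_model() -> Vec<f32> {
    LOAD_CALLS.fetch_add(1, Ordering::SeqCst);
    vec![0.25, 0.75]
}

// Returns the shared model, running `load_model` only on the first call.
fn model() -> &'static [f32] {
    static MODEL: OnceLock<Vec<f32>> = OnceLock::new();
    MODEL.get_or_init(load_model)
}

// Number of times `load_model` has actually run.
fn load_count() -> usize {
    LOAD_CALLS.load(Ordering::SeqCst)
}
```

In a loop like the one in `generate_route`, every call after the first would reuse the cached value instead of rebuilding it.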

Here is my Cargo.toml (I don't think any of the other dependencies are incompatible with WASM):

[package]
name = "graffiti"
version = "0.1.0"
edition = "2021"
description = "AI for climbing routesetting"
authors = ["Sebastiano Rebonato-Scott"]
license = "MIT"
repository = "https://github.com/all-c-a-p-s/Graffiti"

[dependencies]
console_error_panic_hook = "0.1.7"
eframe = "0.28.1"
egui = "0.28.1"
egui_extras = { version = "0.28.1", features = ["all_loaders"] }
getrandom = { version = "0.2", features = ["js"] }
image = { version = "0.25.2", features = ["jpeg", "png"] }
log = "0.4.22"
ndarray = "0.16.0"
ort = "2.0.0-rc.4"
wasm-bindgen = "0.2.92"
web-sys = "0.3.69"

[package.metadata.bundle]
name = "graffiti"
identifier = "io.github.all-c-a-p-s.graffiti"

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen-futures = "0.4"

[lints.clippy]
all = "warn"

I also tried compiling to WASM with wasm-pack instead of trunk, but got the same error. Before I switched from tract-onnx to ort, the WASM build worked fine, which is why I think ort is causing the problem. Sorry for the very long post, and thanks for your time.

decahedron1 commented 2 months ago

I can no longer support WASM; it'll be removed soon. Sorry. I suggest you switch back to tract.