tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

`train-minimal` doesn't work on wasm after async read #865

Closed AlexErrant closed 11 months ago

AlexErrant commented 11 months ago

I would like to run train-minimal in the browser. It compiled as of https://github.com/burn-rs/burn/pull/830, but not after https://github.com/burn-rs/burn/pull/833.

In particular, there's an into_scalar call here:

https://github.com/burn-rs/burn/blob/84e74df3b988baa753c82618ad23b0013acc02cb/burn-train/src/metric/acc.rs#L55

into_scalar has different signatures depending on the target_family:

https://github.com/burn-rs/burn/blob/84e74df3b988baa753c82618ad23b0013acc02cb/burn-tensor/src/tensor/api/numeric.rs#L12-L34

Namely, one function is async, and one is not.
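As a sketch of what that target-dependent split looks like, here is a minimal cfg-gated method on a made-up type (illustrative only; this is not Burn's actual code):

```rust
// Sketch of a cfg-gated API surface: the same method name resolves to a
// synchronous function on native targets and an async one on wasm.
// (`Scalar` is a made-up type; Burn's real `into_scalar` lives on Tensor.)
struct Scalar(f32);

impl Scalar {
    /// Native targets can block, so the value is returned directly.
    #[cfg(not(target_family = "wasm"))]
    fn into_scalar(self) -> f32 {
        self.0
    }

    /// On wasm, blocking is forbidden, so the same method must be async.
    #[cfg(target_family = "wasm")]
    async fn into_scalar(self) -> f32 {
        self.0
    }
}

fn main() {
    // On a native target this compiles to the synchronous branch.
    #[cfg(not(target_family = "wasm"))]
    {
        let v = Scalar(0.5).into_scalar();
        println!("{v}");
    }
}
```

Any caller compiled for wasm sees only the async signature, which is why a sync call site like the one in `acc.rs` stops compiling.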

I see that "training isn't fully supported on wasm anyways". Really my goal is to make https://github.com/open-spaced-repetition/fsrs-rs compile to wasm, including its ability to train. Is this something that can be resolved by a motivated Rust newbie? (Me.) Despite staring at #833 for a while, I still don't understand how Reader prevents the virality of async.

AlexErrant commented 11 months ago

Async traits just landed in main. This avoids async-trait's heap allocations:

> you stop paying the cost of a heap allocation per async fn method call

Is this a viable solution?
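For context, a minimal sketch of what `async fn` in traits looks like, with a tiny hand-rolled `block_on` so the example runs without a runtime (the trait and types are made up, not Burn code; a real program would use a runtime like Tokio):

```rust
use std::future::Future;
use std::pin::pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// With native async-fn-in-trait (stable since Rust 1.75), no `async-trait`
// macro is needed and no Box is allocated per call.
trait Metric {
    async fn value(&self) -> f64;
}

struct Accuracy(f64);

impl Metric for Accuracy {
    async fn value(&self) -> f64 {
        self.0
    }
}

// Waker that does nothing: enough to poll a future that never suspends.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

// Tiny single-future executor, just to make the example self-contained.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    let acc = Accuracy(0.93);
    println!("{}", block_on(acc.value()));
}
```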

nathanielsimard commented 11 months ago

Training on wasm is tricky; a lot of the code inside burn-train heavily uses std and the file system. The Reader type really targets backend developers rather than actual Burn users; it avoids having to add async to every trait/function signature. The problem with async on wasm is that you can never block! That's extremely constraining, since async will spread everywhere. On native targets this isn't really an issue, because a runtime like Tokio lets you block on a future at the boundary.
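For readers following along, a hypothetical sketch of the idea: trait methods return a `Reader<T>` (a plain value, not a future), so the traits themselves stay synchronous, and only the final `read()` call differs per target. All names here are illustrative, and the wasm/async variant is omitted so the sketch compiles natively:

```rust
// Simplified Reader-style type. On native targets the value is computed
// eagerly and `read` just unwraps it; on wasm a second variant would hold
// a future and `read` would be async (omitted here for brevity).
enum Reader<T> {
    Concrete(T),
}

impl<T> Reader<T> {
    #[cfg(not(target_family = "wasm"))]
    fn read(self) -> T {
        match self {
            Reader::Concrete(value) => value,
        }
    }
}

// A backend trait can now return `Reader<f32>` without being async itself,
// so async does not spread through every trait signature.
trait Backend {
    fn into_scalar(&self) -> Reader<f32>;
}

struct Ndarray;

impl Backend for Ndarray {
    fn into_scalar(&self) -> Reader<f32> {
        Reader::Concrete(1.0)
    }
}

fn main() {
    println!("{}", Ndarray.into_scalar().read());
}
```

The async-ness is pushed to the leaf call (`read`), which only end users on wasm have to await.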

Now, this doesn't mean that you can't use Burn to train models on wasm, but I'm unsure if you should be using burn-train. The project uses a lot of threads and channels to do async work like rendering the terminal UI and computing the metrics, which isn't really optimal when using an environment that doesn't support threading.

If you have more information on how you plan to train on wasm, it would be useful. I imagine it worked before because you used a backend that doesn't require async (like ndarray), but using wgpu couldn't possibly work.

AlexErrant commented 11 months ago

> how you plan to train on wasm

I'm building a wasm/js wrapper for https://github.com/open-spaced-repetition/fsrs-rs here, and it currently only exposes two functions, neither of which are training related. To be explicit: fsrs-wasm compiled to wasm as of #830; whether it works at runtime is an open question. I don't even know the status of the training feature on fsrs-rs; @dae @L-M-Sherlock please feel free to comment.

I can't speak for others in this issue, but for my purposes, performance is a nice-to-have, not a necessity. (I kinda gave up on "optimal" when I started learning JavaScript ( ._. )) If end-users want better perf, I would be thrilled to point them to a non-browser-based solution. I see your "just because you can doesn't mean you should" and agree. I'm currently using ndarray, and while w(eb)GPU would be nice, I don't need it.

> burn-train heavily uses std and the file system

Hm, I see that in fsrs-rs here. I'm a little confused how/why fsrs-wasm compiled then, I would've expected wasm-pack to throw... Still, browsers now have OPFS which "offers low-level, byte-by-byte file access" and "also has a set of synchronous calls". I'm willing to explore if that's sufficient to get burn-train working in the browser. I believe train-minimal, used by fsrs-rs, doesn't include the TUI or metrics so that's less surface area.

L-M-Sherlock commented 11 months ago

Even if Burn supports training the model in wasm, passing the dataset into wasm is still a tricky problem. wasm_bindgen::convert::FromWasmAbi doesn't support complex objects, so we may need to pass a JSON string into the wasm module. Its performance is also unknown (I guess it's very slow).

AlexErrant commented 11 months ago

I was thinking about passing a (Shared)ArrayBuffer back and forth: https://developer.mozilla.org/en-US/docs/WebAssembly/JavaScript_interface/Memory

> If you want to access the memory created in JS from Wasm or vice versa, you can pass a reference to the memory from one side to the other.

No idea about its perf, but it'll be better than JSON.
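To illustrate the flat-buffer idea, here's a sketch decoding a packed f32 buffer (like what a JS Float32Array might share with wasm linear memory) into records on the Rust side. The 3-field record layout is made up for illustration and is not fsrs-rs's actual schema:

```rust
// Hypothetical review record; fields and their order are an assumed
// convention that the JS caller and the Rust side would have to agree on.
#[derive(Debug, PartialEq)]
struct Review {
    rating: f32,
    delta_t: f32,
    card_id: f32,
}

// Decode a flat buffer of f32s into records, three values per record.
// No JSON parsing, no per-field allocation beyond the output Vec.
fn decode(buf: &[f32]) -> Vec<Review> {
    buf.chunks_exact(3)
        .map(|c| Review { rating: c[0], delta_t: c[1], card_id: c[2] })
        .collect()
}

fn main() {
    // Two records packed end to end, as a JS caller might write them.
    let buf = [3.0, 1.0, 42.0, 4.0, 7.0, 42.0];
    let reviews = decode(&buf);
    println!("{} records, first = {:?}", reviews.len(), reviews[0]);
}
```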

L-M-Sherlock commented 11 months ago

Yeah. In my opinion, it would be pretty nice to train FSRS in wasm: it would support a broader range of devices. But I guess it requires substantial development and couldn't be implemented in a short time.

AlexErrant commented 11 months ago

Does Asyncify seem like a possible solution?

> ...the idiomatic code in the original systems language uses a blocking API for the I/O, whereas an equivalent example for the web uses an asynchronous API instead. When compiling to the web, you need to somehow transform between those two execution models, and WebAssembly has no built-in ability to do so just yet. This is where Asyncify comes in. Asyncify is a compile-time feature supported by Emscripten that allows pausing the entire program and asynchronously resuming it later.
>
> Say that you have a similar synchronous call somewhere in your Rust code that you want to map to an async API on the web. Turns out, you can do that too!

https://web.dev/articles/asyncify

This seems to prevent the virality of async from contaminating standard Rust.

nathanielsimard commented 11 months ago

This is really interesting. I guess this strategy would work. I'm not sure, though, how to use it from Rust!