webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web

Tips for parallelization #190

Open schell opened 9 months ago

schell commented 9 months ago

First off, great work: wonnx has been very easy to use, and apart from a few missing operators it "just works".

I'm in the optimization phase of building an app that does inference using wonnx. When I benchmark a model (with criterion), wonnx is just about as fast as onnxruntime. I figured this probably comes down to marshaling the data to the GPU (the shader wonnx generates may run a little faster, but the marshaling takes a little longer). If that's the case, I should be able to improve throughput by running my model in parallel. Unfortunately I am not in control of the models and cannot retrain or re-export them with a dynamic batch size, so instead I opted to edit the ONNX model itself and clone the graph into 64 subgraphs, each with its own input. Even though the edited model validated and produced the expected results, it yielded no gain in throughput (or latency, for that matter). My guess is that the shaders wonnx produces do not execute each subgraph in parallel, but I don't know.
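For reference, my benchmark looks roughly like the sketch below (the model path, input name, and input size are placeholders; pollster is used to block on wonnx's async API):

```rust
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, Criterion};

fn bench_inference(c: &mut Criterion) {
    // Placeholder model path; load the session once, outside the hot loop.
    let session =
        pollster::block_on(wonnx::Session::from_path("model.onnx")).unwrap();
    let input = vec![0.0f32; 3 * 224 * 224]; // placeholder input size

    c.bench_function("wonnx inference", |b| {
        b.iter(|| {
            let mut inputs = HashMap::new();
            // "input" is a placeholder tensor name.
            inputs.insert("input".to_string(), input.as_slice().into());
            pollster::block_on(session.run(&inputs)).unwrap()
        })
    });
}

criterion_group!(benches, bench_inference);
criterion_main!(benches);
```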

My question is: is there a general method of parallelization that might yield "pretty good" results and doesn't involve re-training or other Python tasks? I don't mind editing the ONNX model to use another approach such as SequenceMap (if that's supported), or something similar. Or maybe there's an opportunity to extend the wonnx API to support this out of the box, possibly by issuing multiple compute dispatches over an offset buffer? What do you think?

pixelspark commented 9 months ago

Thanks, and good to hear wonnx is actually being used in production!

The way wonnx runs an ONNX graph is actually pretty simple: after a pass over the graph to perform some optimizations, the graph nodes are topologically sorted, and we generate a shader + invocation for each node (using a coloring algorithm to re-use buffers for intermediate values). The shaders are then invoked in series, which means there is parallelism only within the execution of a single node. For many models this is fine, as their graphs are more or less serial and the GPU can be saturated by running just one node's shader at a time.
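In pseudocode, the execution model looks something like the following (a conceptual sketch, not the actual wonnx source; `topological_sort`, `compile_shader_for` and `bind_buffers_for` are made-up stand-ins for wonnx internals):

```rust
// Conceptual sketch of serial per-node dispatch (hypothetical helpers).
for node in topological_sort(&graph) {
    let pipeline = compile_shader_for(&device, node); // one generated WGSL shader per op
    let bind_group = bind_buffers_for(&device, node); // intermediate buffers re-used via coloring
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // Each node is dispatched after the previous one; parallelism exists
    // only inside a single dispatch.
    pass.dispatch_workgroups(node.x, node.y, node.z);
}
queue.submit(Some(encoder.finish()));
```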

In theory we could improve on this by also executing independent nodes in parallel (this would however complicate the coloring algorithm and require some sort of waiting mechanism for when parallel branches join back together). Another option would be to attempt to generate one big shader containing all the ops (even more complicated).

Your other options are (1) to implement a custom ONNX op with its own shader, or (2) to run multiple wonnx graphs (subgraphs of a bigger model) in parallel yourself, doing the joining/waiting at that level; see the sketch below. The latter solution would also be useful to gain parallelism on a multi-GPU system.
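A rough sketch of what (2) could look like today, assuming you have split the model into per-input subgraph files (the path and tensor name are placeholders, and in practice you would create the sessions once up front rather than per call):

```rust
use std::collections::HashMap;

// Rough sketch: one wonnx Session per subgraph, run concurrently with
// futures::future::join_all. "submodel.onnx" and "input" are placeholders.
async fn run_subgraphs(batch: Vec<Vec<f32>>) {
    let tasks = batch.into_iter().map(|data| async move {
        let session = wonnx::Session::from_path("submodel.onnx")
            .await
            .expect("failed to load session");
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), data.as_slice().into());
        session.run(&inputs).await.expect("inference failed")
    });
    // Wait for all subgraphs; join/merge their outputs at this level.
    let outputs = futures::future::join_all(tasks).await;
    for (i, out) in outputs.iter().enumerate() {
        println!("subgraph {i} outputs: {:?}", out.keys());
    }
}
```

Whether this actually improves throughput depends on the driver being able to schedule work from the sessions concurrently.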

schell commented 9 months ago

Thanks for the explainer! I'm definitely interested in (1) and (2). Can you elaborate on how I might get started with (1) as well as what (2) would look like today?