WebAssembly / wasi-nn

Neural Network proposal for WASI

Eliminate `GraphExecutionContext` #43


abrown commented 11 months ago

This issue proposes a simplification to the wasi-nn API: eliminate the `GraphExecutionContext` state object altogether and instead pass all tensors to and from a single inference call: `compute(list<tensor>) -> result<list<tensor>, ...>`. This change would make `set_input` and `get_output` unnecessary, so they would be removed as well.
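Sketched in guest-side Rust (a hypothetical illustration with stub types; these are not the published wasi-nn bindings, whose real interface is WIT), the change looks roughly like this:

```rust
// Stub types so the sketch compiles; real code would come from bindings.
struct Tensor(Vec<u8>);
struct Graph;
struct GraphExecutionContext;

impl Graph {
    fn init_execution_context(&self) -> GraphExecutionContext {
        GraphExecutionContext
    }
    // Proposed replacement: all inputs in, all outputs out, no state.
    fn compute(&self, inputs: Vec<Tensor>) -> Vec<Tensor> {
        inputs // placeholder; a real host would run inference here
    }
}

impl GraphExecutionContext {
    fn set_input(&mut self, _index: u32, _tensor: Tensor) {}
    fn compute(&mut self) {}
    fn get_output(&self, _index: u32) -> Tensor {
        Tensor(Vec::new())
    }
}

fn main() {
    let graph = Graph;

    // Today: stateful, multi-call inference.
    let mut ctx = graph.init_execution_context();
    ctx.set_input(0, Tensor(vec![1, 2, 3, 4]));
    ctx.compute();
    let _output = ctx.get_output(0);

    // Proposed: one stateless call; set_input/get_output disappear.
    let _outputs = graph.compute(vec![Tensor(vec![1, 2, 3, 4])]);
}
```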

As background, the WITX IDL is the reason `GraphExecutionContext` exists. As I understood it back when wasi-nn was originally designed, WITX forced us to pass an empty "pointer + length" buffer across to the host so that the host could fill it. This led to `get_output(...)`, which included an `index` parameter for retrieving tensors from multi-output models (multiple outputs are uncommon, but a possibility that must be handled). Because `get_output` was now separate from `compute`, we needed some state to track the inference request: hence `GraphExecutionContext`.
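A hedged reconstruction of that witx-era shape (the real signatures differ in detail) makes the problem concrete: the guest must pre-allocate the buffer, and the context handle is the state tying `get_output` back to the earlier `compute`:

```rust
// Hedged reconstruction; not the exact historical witx signature.
type GraphExecutionContext = u32;
type Errno = u16;

extern "C" {
    // The guest pre-allocates `out_buffer`; the host fills it. `index`
    // selects among a multi-output model's tensors. Because this call is
    // separate from compute(), the `ctx` handle must exist to carry
    // inference state between the two.
    fn get_output(
        ctx: GraphExecutionContext,
        index: u32,
        out_buffer: *mut u8,
        out_buffer_max_size: u32,
    ) -> Errno;
}
```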

Now, with WIT, we can expect the ABI to host-allocate into our WebAssembly linear memory for us. This is better in two ways: the guest no longer has to pre-allocate an output buffer and ask the host to fill it, and with outputs returned directly from `compute`, the `set_input`/`get_output` calls and the state behind them are no longer needed.
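Concretely, WIT bindings generate a realloc export, conventionally `cabi_realloc`, that the host calls to place returned values such as output tensors directly into the guest's linear memory. A simplified sketch of that export (real generated code handles more edge cases, e.g. zero-size allocations and allocation failure):

```rust
use std::alloc::{self, Layout};

// Simplified version of the realloc export that WIT bindings generate.
#[no_mangle]
pub unsafe extern "C" fn cabi_realloc(
    ptr: *mut u8,
    old_size: usize,
    align: usize,
    new_size: usize,
) -> *mut u8 {
    let new_layout = Layout::from_size_align(new_size, align).unwrap();
    if old_size == 0 {
        // Fresh allocation: the host is asking for space to write into,
        // e.g. for the output tensors returned by compute().
        return alloc::alloc(new_layout);
    }
    let old_layout = Layout::from_size_align(old_size, align).unwrap();
    alloc::realloc(ptr, old_layout, new_size)
}
```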

One consideration here is ML framework compatibility: some frameworks (e.g., OpenVINO) expose an equivalent to `GraphExecutionContext` in their external API that implementations of wasi-nn must call. But because this context object can be created inside the implementation, there is no compatibility issue. Implementations of `compute` would do a bit more per call than they currently do, but no more work overall; the work simply moves out of the removed calls and into `compute`.
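For instance, a host-side implementation of the new `compute` might look like the following sketch (hypothetical trait and method names, not actual Wasmtime or OpenVINO code), with the framework's context object created, used, and dropped entirely inside the call:

```rust
struct Tensor(Vec<u8>);
struct Error;

// Stand-ins for a framework API such as OpenVINO's; names are illustrative.
trait Backend {
    type Request: InferRequest;
    fn create_infer_request(&self) -> Result<Self::Request, Error>;
}

trait InferRequest {
    fn set_input(&mut self, index: usize, tensor: &Tensor) -> Result<(), Error>;
    fn infer(&mut self) -> Result<(), Error>;
    fn output_count(&self) -> usize;
    fn get_output(&self, index: usize) -> Result<Tensor, Error>;
}

// The context object lives entirely inside compute: created, used, dropped.
// Per-call creation is shown for simplicity; an implementation could cache
// the request on the graph to avoid repeated setup cost.
fn compute<B: Backend>(graph: &B, inputs: &[Tensor]) -> Result<Vec<Tensor>, Error> {
    let mut request = graph.create_infer_request()?;
    for (index, tensor) in inputs.iter().enumerate() {
        request.set_input(index, tensor)?;
    }
    request.infer()?;
    (0..request.output_count())
        .map(|i| request.get_output(i))
        .collect()
}
```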

Another consideration is memory-copying overhead: will WIT force us to copy the tensor bytes across the guest-host boundary in both directions? Tensors can be large, and additional copies could be expensive. For output tensors, this may be unavoidable: when a tensor is generated on the host side during inference, it must somehow be made accessible to the Wasm guest, and copying is a simple solution. For input tensors, though, this discussion suggests that WIT itself may not force the copy. If tensor copying becomes a bottleneck, perhaps WIT resources could be the solution (see the sketch below).
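As a rough illustration of that last idea (nothing below is in the wasi-nn spec; every name is hypothetical), tensors-as-resources would let opaque handles, rather than bytes, cross the boundary, so chained models pay no intermediate copy:

```rust
// Opaque resource handle; the tensor bytes stay on the host.
struct TensorHandle(u32);

struct Graph;

impl Graph {
    // Handles in, handles out: no bulk tensor data crosses the boundary.
    fn compute(&self, inputs: Vec<TensorHandle>) -> Vec<TensorHandle> {
        inputs // placeholder; a real host would run inference here
    }
}

impl TensorHandle {
    // Only this call pays for a guest-host copy, and only when the guest
    // actually needs the raw bytes.
    fn bytes(&self) -> Vec<u8> {
        Vec::new() // placeholder for a host call
    }
}

fn main() {
    let (encoder, decoder) = (Graph, Graph);
    let hidden = encoder.compute(vec![TensorHandle(0)]);
    let logits = decoder.compute(hidden); // chained with zero copies
    let _bytes = logits[0].bytes(); // a single copy at the very end
}
```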

shschaefer commented 9 months ago

@abrown, the session/execution-context interface caches configuration parameters and device selection. Are you going to leave this state on the graph, initialized during model load? And would the call to `compute` then accept the graph as an input instead of the `GraphExecutionContext`?

The additional scenario for an execution context is carrying this parameterization across multiple models (chaining) or multiple modes: natively and in the browser, we pass the GPU context between DirectX and ONNX, or between WebGPU and WebNN.