WebAssembly / wasi-nn

Neural Network proposal for WASI

Question: GEMM on 8-bit buffers? #4

Closed: abrown closed this issue 10 months ago.

abrown commented 3 years ago

This issue serves as a follow-up to the following discussion in the W3C Machine Learning Workshop. The question was about fp16 and i8 support in Wasm, specifically as it relates to ML models that may need these data types:

From Sangwhan Moon to Everyone:  07:41 AM
https://github.com/WebAssembly/simd/tree/master/proposals/simd is on-going work, I recall seeing i8 types in there.

From Kenneth Heafield (University of Edinburgh) to Everyone:  07:41 AM
https://github.com/webmachinelearning/webnn/issues/84

From Andrew Brown to Everyone:  07:42 AM
While it is true that Wasm doesn't have f16 and i8 types, it is possible to create buffers in memory and pack them (through shifting, etc.) so they would "look like" f16/i8 buffers--is this not enough?
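The packing Andrew describes can be made concrete. Below is a minimal sketch, assuming a little-endian layout and two's-complement int8 (the helper names are hypothetical and not part of wasi-nn): four int8 values are packed into one 32-bit word with shifts and masks, and an f16 bit pattern is produced with the `struct` module's `e` (IEEE 754 binary16) format. Packing only solves storage; arithmetic on the packed values still needs unpacking or dedicated instructions, which is the speed concern raised next.

```python
import struct


def pack_i8x4(values):
    """Pack four signed 8-bit ints into one 32-bit word, lowest byte first."""
    if len(values) != 4:
        raise ValueError("expected exactly four values")
    word = 0
    for i, v in enumerate(values):
        if not -128 <= v <= 127:
            raise ValueError(f"{v} does not fit in int8")
        # Mask to the two's-complement byte, then shift it into lane i.
        word |= (v & 0xFF) << (8 * i)
    return word


def unpack_i8x4(word):
    """Inverse of pack_i8x4: shift each lane down and sign-extend it."""
    out = []
    for i in range(4):
        byte = (word >> (8 * i)) & 0xFF
        out.append(byte - 256 if byte >= 128 else byte)
    return out


def f16_bits(x):
    """IEEE 754 binary16 bit pattern of x, via struct's 'e' format."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def f16_from_bits(bits):
    """Decode a binary16 bit pattern back to a Python float."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


word = pack_i8x4([1, -1, 127, -128])
print(hex(word))              # 0x807fff01
print(unpack_i8x4(word))      # [1, -1, 127, -128]
print(hex(f16_bits(1.0)))     # 0x3c00
print(f16_from_bits(0xC000))  # -2.0
```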

From Kenneth Heafield (University of Edinburgh) to Everyone:  07:44 AM
While we're talking about size, speed matters.  The relevant WebAssembly issue is https://github.com/WebAssembly/simd/issues/328 .
So WasiNN would do GEMM for me on the 8-bit buffers?

From Andrew Brown to Everyone:  07:45 AM
I believe it could; right now we are working on a POC that exposes what OpenVINO can do through the wasi-nn API

From Kenneth Heafield (University of Edinburgh) to Everyone:  07:47 AM
Can it call _mm512_permutexvar_epi16 to implement lookup tables for operators?  And if all I have is an Intel CPU, will WebAssembly allow it to call pmaddubsw or vpmaddubsw despite the Intel-specific saturation behavior that doesn't exist on ARM / GPU?
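To make the saturation point concrete, here is a scalar Python model of what PMADDUBSW computes for each 16-bit output lane (an illustration only, not a SIMD implementation; the function names are hypothetical). It multiplies unsigned bytes from the first operand by signed bytes from the second, adds adjacent pairs, and saturates the sum to int16. When both operands are near their extremes, the exact pair sum falls outside the int16 range and the result clamps, so a portable instruction would have to either reproduce this clamping on other hardware or define a different result, which is the mismatch Kenneth describes.

```python
INT16_MIN, INT16_MAX = -32768, 32767


def saturate_i16(x):
    """Clamp an integer to the signed 16-bit range."""
    return max(INT16_MIN, min(INT16_MAX, x))


def pmaddubsw(a_u8, b_i8):
    """Scalar model of x86 PMADDUBSW on equal-length byte lanes.

    a_u8: unsigned bytes (0..255); b_i8: signed bytes (-128..127).
    Output lane i is a[2i]*b[2i] + a[2i+1]*b[2i+1], saturated to int16.
    """
    if len(a_u8) != len(b_i8) or len(a_u8) % 2:
        raise ValueError("operands must have the same even length")
    if any(not 0 <= a <= 255 for a in a_u8) or any(not -128 <= b <= 127 for b in b_i8):
        raise ValueError("operand out of u8/i8 range")
    return [
        saturate_i16(a_u8[i] * b_i8[i] + a_u8[i + 1] * b_i8[i + 1])
        for i in range(0, len(a_u8), 2)
    ]


print(pmaddubsw([1, 2, 3, 4], [5, -6, 7, 8]))  # [-7, 53]
print(pmaddubsw([255, 255], [127, 127]))       # [32767]  (exact sum 64770)
print(pmaddubsw([255, 255], [-128, -128]))     # [-32768] (exact sum -65280)
```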
abrown commented 3 years ago

To answer Kenneth's question, let me offer a thought: wasi-nn exists because platform-specific operations (perhaps like the ones you mention) are unlikely to be exposed through the Wasm SIMD specification. That specification (and Wasm in general) has prioritized portability, so it is difficult to expose Intel-specific operations (or those of any other platform, really) directly. Enter WASI and this proposal, wasi-nn: by exposing ML functionality as a system interface, we can implement that functionality using optimized, platform-specific operations, which should give you access to the operations you are looking to use. Some caveats:

abrown commented 3 years ago

cc: @mingqiusun

kpu commented 3 years ago

Background: https://browser.mt/ aims to run client-side machine translation with browser integration https://www.w3.org/2020/06/machine-learning-workshop/talks/privacy_focused_machine_translation_in_firefox.html . We're running natively at reasonable speed https://neural.mt/speed/ with 8-bit GEMM dispatched by CPUID with different SIMD lengths. But if this is an extension, then we're stuck with (currently) slow web APIs.
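Dispatching "by CPUID with different SIMD lengths" can be sketched as a preference-ordered table: pick the widest 8-bit GEMM kernel the CPU supports and fall back to a portable one. This is a hedged illustration, not the browser.mt code. The kernel names are hypothetical, and the feature set is passed in explicitly (using Linux `/proc/cpuinfo` flag spellings) because a native implementation would query CPUID directly, which the Python standard library does not expose.

```python
# Hypothetical kernel names, ordered from widest SIMD to narrowest.
# Feature strings follow Linux /proc/cpuinfo flag spellings.
KERNEL_PREFERENCE = [
    ("avx512_vnni", "gemm_i8_avx512vnni"),
    ("avx512bw", "gemm_i8_avx512bw"),
    ("avx2", "gemm_i8_avx2"),
    ("ssse3", "gemm_i8_ssse3"),  # SSSE3 introduced PMADDUBSW
]
PORTABLE_KERNEL = "gemm_i8_portable"


def select_gemm_kernel(cpu_features):
    """Return the first kernel whose required feature is present."""
    for feature, kernel in KERNEL_PREFERENCE:
        if feature in cpu_features:
            return kernel
    return PORTABLE_KERNEL


def linux_cpu_features(path="/proc/cpuinfo"):
    """Best-effort read of x86 CPU flags on Linux; empty set if unavailable."""
    try:
        with open(path) as f:
            for line in f:
                if line.startswith("flags"):
                    return set(line.split(":", 1)[1].split())
    except OSError:
        pass
    return set()


print(select_gemm_kernel({"sse2", "ssse3", "avx2"}))  # gemm_i8_avx2
print(select_gemm_kernel(set()))                      # gemm_i8_portable
print(select_gemm_kernel(linux_cpu_features()))       # depends on the machine
```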

8-bit GEMM is our most expensive kernel and we want it in Firefox. Keep in mind that I also care about different matrix sizes than the vision people https://github.com/apache/incubator-mxnet/issues/17980 .
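For reference, the computation such a kernel has to accelerate is an ordinary matrix multiply over int8 inputs with a wider (here int32-range) accumulator. The naive pure-Python sketch below uses a hypothetical function name and makes no attempt at the blocking or SIMD that make real kernels fast; it only shows why the accumulator must be wider than the inputs: a single 127 × 127 product already exceeds the int8 range, and summing three of them overflows int16 as well.

```python
INT8_MIN, INT8_MAX = -128, 127
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def gemm_i8(a, b):
    """C = A @ B for int8 matrices A (m x k) and B (k x n), int32 accumulation."""
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError("inner dimensions do not match")
    n = len(b[0])
    for row in a + b:
        if any(not INT8_MIN <= v <= INT8_MAX for v in row):
            raise ValueError("input value outside int8 range")
    c = [[0] * n for _ in range(m)]
    # i-p-j loop order walks rows of B and C sequentially.
    for i in range(m):
        c_row = c[i]
        for p in range(k):
            a_ip = a[i][p]
            b_row = b[p]
            for j in range(n):
                c_row[j] += a_ip * b_row[j]
    # Python ints never overflow, so check the int32 contract explicitly.
    for row in c:
        if any(not INT32_MIN <= v <= INT32_MAX for v in row):
            raise OverflowError("result exceeds int32 accumulator range")
    return c


print(gemm_i8([[1, 2], [3, 4]], [[5, 6], [7, 8]]))        # [[19, 22], [43, 50]]
print(gemm_i8([[127, 127, 127]], [[127], [127], [127]]))  # [[48387]] (> int16 max)
```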

We can export to 3 ONNX graphs that get glued together. The exported models are somewhat inefficient though, even natively, because shortlisting is crucial to performance. In shortlisting, the system guesses what words will occur then selects them in the output matrix, avoiding a full multiply of the output matrix. Those guesses are based on set operations. So I'm hesitant to go for a full "give us your graph" approach when much of the work to get the speed entailed customizing our C++ toolkit including operators that don't exist in ONNX. But if I can just call GEMM, logsoftmax, elementwise kernels, etc. that's most of what I need.

mingqiusun commented 3 years ago

@kpu Is there any machine learning framework that supports shortlisting?

kpu commented 3 years ago

Sockeye, OpenNMT, and Marian all do shortlisting. Sockeye does it on the Python side because MXNet doesn't have it. OpenNMT and Marian have integrated C++ stacks with CPU and GPU backends. It's not hard per se: read the input, do some hash table operations to take the union of predicted tokens for each token in the batch, and run a select operation on the output matrix. It's just not something ONNX supports out of the box.
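The shortlisting steps described above can be sketched in a few lines. This is a hedged, pure-Python illustration rather than the Marian, OpenNMT, or Sockeye implementation: `lexical_table` (a hypothetical mapping from a source token to likely target-vocabulary ids) stands in for whatever predicts the candidate words, and the output layer is assumed to be a hidden × vocabulary weight matrix. The union over the batch is a set operation, and the select keeps only the shortlisted columns, so the final multiply costs hidden × |shortlist| per row instead of hidden × vocabulary.

```python
def build_shortlist(batch_tokens, lexical_table, always_include=()):
    """Union of predicted target ids for every token in the batch."""
    shortlist = set(always_include)
    for sentence in batch_tokens:
        for token in sentence:
            shortlist.update(lexical_table.get(token, ()))
    return sorted(shortlist)


def select_columns(matrix, columns):
    """Keep only the given columns of a row-major matrix."""
    return [[row[j] for j in columns] for row in matrix]


def shortlisted_logits(hidden, w_out, shortlist):
    """Logits restricted to the shortlist: hidden (batch x d) @ w_out[:, shortlist]."""
    w_small = select_columns(w_out, shortlist)
    return [
        [sum(h[p] * w_small[p][j] for p in range(len(h))) for j in range(len(shortlist))]
        for h in hidden
    ]


# Toy example: 8 target ids, hidden size 2; id 0 plays a hypothetical end-of-sentence token.
lexical_table = {"das": [1, 3], "haus": [3, 7]}
shortlist = build_shortlist([["das", "haus"]], lexical_table, always_include=(0,))
w_out = [list(range(8)), [10 * j for j in range(8)]]  # 2 x 8
print(shortlist)                                       # [0, 1, 3, 7]
print(shortlisted_logits([[1, 1]], w_out, shortlist))  # [[0, 11, 33, 77]]
```

Index `j` of each logits row corresponds to vocabulary id `shortlist[j]`, so a decoder maps its choices back through the shortlist.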

In any case, my main performance interest is getting fast 8-bit GEMM into the browser as soon as possible, through whichever standard. The other kernels are icing.

kpu commented 3 years ago

Paging @bjacob from https://github.com/WebAssembly/simd/issues/328

kpu commented 3 years ago

Let's not worry about the shortlisting; I can just do that in WebAssembly with a hash table and provide it as an extra input.

What I do want is 8-bit GEMM in the browser.

I feel like WebNN is pursuing a full-package approach ("Expected completion: [CR Q1 2022]") that will be nice in the long term but is much bigger than what I need to get reasonable efficiency.

abrown commented 3 years ago

Thinking about this more, the Wasm SIMD repo issues (e.g. https://github.com/WebAssembly/simd/pull/127, https://github.com/WebAssembly/simd/issues/328, https://github.com/WebAssembly/simd/issues/224) and your comments there seem the most likely way to get, e.g., PMADDUBSW in the browser. WASI and modules like wasi-nn are not primarily aimed at browser consumption though at some point someone may make that work.

geekbeast commented 11 months ago

It's not a good way to get into browser, since browsers are unlikely to support custom operators for security reasons.

@abrown I feel like this issue might be a good candidate for closing, both for fit and for inactivity.

kpu commented 11 months ago

> It's not a good way to get into browser, since browsers are unlikely to support custom operators for security reasons.

I find the timing here a bit ironic given that Firefox 118 just launched in-browser machine translation https://www.mozilla.org/en-US/firefox/118.0/releasenotes/ powered by a custom 8-bit GEMM operator because WASM was too slow.

geekbeast commented 11 months ago

Hi, sorry I wasn't clear. Browsers are unlikely to support custom user defined operators loaded from the internet. That means that browsers would not only have to support wasi-nn, but they would have to provide an implementation linking against existing framework backends. I believe this is what @abrown was alluding to in his comment.

I'm glad that Firefox decided to implement their own 8-bit GEMM operator and I hope that it meets your needs.

abrown commented 10 months ago

Let's close this: @kpu's use cases are more browser-specific and WebNN is the better fit for that.