webmachinelearning / webnn

🧠 Web Neural Network API
https://www.w3.org/TR/webnn/

Consider removing `lstm` and `gru` operators #689

Status: Open

a-sully commented 2 weeks ago

Related to #453. One recommendation of that issue is:

  • Remove model-specific instructions like LSTM and GRU

Current support across backends

| Operator | DirectML | CoreML | TFLite |
| --- | --- | --- | --- |
| lstm | ✅ | ✅ | ⚠️† |
| gru | ✅ | ✅ | - |

†TFLite delegates generally do not support the entire TFLite opset. For instance, TFLite GPU delegates only support a very basic variant of LSTM which does not support most of the parameters specified by WebNN.

What does "supporting" LSTM really mean?

Higher-level operators such as lstm and gru tend to have more knobs to turn, and each additional knob increases the likelihood that WebNN will not actually be able to use the backend's operator of the same name. There are many variants of LSTM. Just because a backend has an LSTM operator does not mean that operator can express everything required by the variant of LSTM that WebNN specifies. Meanwhile, frameworks sitting on top of WebNN may only take advantage of WebNN’s lstm operator if it is exactly the variant the calling framework needs.
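
To make the "knobs" concrete, here is a sketch of what a single WebNN `lstm()` call looks like with most of its optional parameters spelled out. Option names follow the spec's MLLstmOptions dictionary, but the tensor shapes, the descriptor form, and the string form of `activations` are assumptions of this sketch, not normative:

```ts
// Sketch only: shows how many knobs WebNN's lstm() exposes, each of which a
// backend's own LSTM operator may or may not honor. Shapes assume the packed
// "iofg" gate layout; error handling and real weight data are omitted.
const steps = 16, batch = 1, inputSize = 32, hiddenSize = 64, numDirections = 2;

const context = await (navigator as any).ml.createContext();
const builder = new (globalThis as any).MLGraphBuilder(context);

const f32 = (shape: number[]) => ({ dataType: 'float32', shape });
const input = builder.input('input', f32([steps, batch, inputSize]));
const weight = builder.input('weight', f32([numDirections, 4 * hiddenSize, inputSize]));
const recurrentWeight = builder.input('recurrentWeight', f32([numDirections, 4 * hiddenSize, hiddenSize]));
const bias = builder.input('bias', f32([numDirections, 4 * hiddenSize]));

// Returns the final hidden state, the final cell state and (with returnSequence)
// the full output sequence.
const outputs = builder.lstm(input, weight, recurrentWeight, steps, hiddenSize, {
  bias,
  recurrentBias: bias,        // illustrative; a real model has separate tensors
  // peepholeWeight, initialHiddenState, initialCellState: yet more knobs
  returnSequence: true,
  direction: 'both',          // 'forward' | 'backward' | 'both'
  layout: 'iofg',             // gate ordering of the packed weight tensors
  activations: ['sigmoid', 'tanh', 'tanh'],  // per-gate activations (see tables below)
});
```

Each of these options is a point where a backend's native LSTM can diverge from what WebNN specifies.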

For example, TFLite's variant of LSTM uses coupled "input" and "forget" gates (CIFG), whereas this option is not available in WebNN - Chromium's DML implementation currently does not couple these gates. User agents implementing WebNN on TFLite cannot use its LSTM operator, and neither can frameworks calling into WebNN use its LSTM operator if they want the CIFG behavior.
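
For concreteness, CIFG means the forget gate is derived from the input gate rather than having its own weights. A framework that wants this behavior on top of WebNN has to build the step out of primitives itself, roughly like the hypothetical helper below (weight packing, biases, and sequence/direction handling omitted; all names are illustrative):

```ts
// Rough sketch of one CIFG-style LSTM cell step built from WebNN primitives.
// `one` is a scalar constant operand with value 1, created elsewhere; W.* and
// R.* are per-gate input/recurrent weight operands.
function cifgLstmStep(builder: any, one: any, x: any, h: any, c: any, W: any, R: any) {
  const linear = (Wg: any, Rg: any) =>
    builder.add(builder.matmul(x, Wg), builder.matmul(h, Rg));

  const i = builder.sigmoid(linear(W.i, R.i));  // input gate
  const f = builder.sub(one, i);                // forget gate coupled to i: f = 1 - i
  const g = builder.tanh(linear(W.g, R.g));     // cell candidate
  const o = builder.sigmoid(linear(W.o, R.o));  // output gate

  const cNext = builder.add(builder.mul(f, c), builder.mul(i, g));
  const hNext = builder.mul(o, builder.tanh(cNext));
  return { hNext, cNext };
}
```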

Let's look at the problem @philloooo mentioned about the activations supported by LSTM across various backends: https://github.com/webmachinelearning/webnn/issues/573#issuecomment-1984086150

Supported activation functions for LSTM

| Activation | ONNX | DirectML† | CoreML | TFLite | WebNN |
| --- | --- | --- | --- | --- | --- |
| affine | optional | ??? | - | - | - |
| clamp | - | ??? | - | - | ✅ |
| elu | optional | ??? | - | - | ✅ |
| gelu | - | ??? | - | - | ✅ |
| hardSigmoid | optional | ??? | ⚠️†† | - | ✅ |
| hardSwish | - | ??? | - | - | ✅ |
| leakyRelu | optional | ??? | - | - | ✅ |
| linear | - | ??? | ⚠️†† | - | ✅ |
| relu | ✅††† | ✅ | ✅ | ✅ | ✅ |
| reluN1To1 | - | ??? | - | ✅ | - |
| relu6 | - | ??? | - | ✅ | - |
| scaledTanh | optional | ??? | ⚠️†† | - | - |
| sign | - | ??? | - | ✅ | - |
| softmax | - | ??? | - | - | ✅ |
| softplus | optional | ??? | - | - | ✅ |
| softsign | optional | ??? | - | - | ✅ |
| sigmoid | ✅††† | ✅ | ✅ | - | ✅ |
| tanh | ✅††† | ✅ | ✅ | ✅ | ✅ |
| thresholdedRelu | optional | ??? | - | - | - |

†I couldn't find documentation in DirectML saying which activations are supported by which operators. @ folks familiar with DML, please feel free to chime in!

††Does not support passing alpha or beta values, as far as I can tell

†††I'm assuming, since these activations are not listed as optional by ONNX


Aside: Now that MLActivations are no longer used for op fusion, we should consider removing MLActivations which do not make sense for use with recurrent operators.


What activations can be specified on each backend?

Let's also remember that WebNN's lstm operator has multiple activations.

| Gate | DirectML† | CoreML | TFLite | WebNN |
| --- | --- | --- | --- | --- |
| input (i) | f() | recurrent_activation | fused_activation_function | activations[0] |
| output (o) | f() | recurrent_activation | fused_activation_function | activations[0] |
| forget (f) | f() | recurrent_activation | fused_activation_function | activations[0] |
| cell (g) | g() | cell_activation | always sigmoid, I think? | activations[1] |
| hidden (h) | h() | activation | always sigmoid, I think? | activations[2] |

†DML also supports passing different activations for LSTM's forward and backward passes

Summary

Reviewing the Guidelines for new operations to retroactively evaluate whether lstm and gru meet them, this guideline stands out:

  • Prefer primitives over new high level operations but consider performance consequences

The rationale for having these operators in WebNN is that the user agent's implementation can plumb them directly to the backend's operator of the same name; otherwise there is no benefit over having the framework sitting on top of WebNN decompose the operator itself.
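
For comparison, a framework that decomposes gru itself only needs primitives WebNN already has. Below is a rough sketch of a single GRU step for one common variant, ignoring biases, weight layout, and the resetAfter / linear-before-reset distinction; `one`, `W.*`, and `R.*` are illustrative names, not spec identifiers:

```ts
// Rough sketch of one GRU cell step built from WebNN primitives, to illustrate
// that the decomposition is straightforward for a framework to emit itself.
// `one` is a scalar constant operand with value 1; W.* / R.* are per-gate weights.
function gruStep(builder: any, one: any, x: any, h: any, W: any, R: any) {
  const linear = (Wg: any, Rg: any) =>
    builder.add(builder.matmul(x, Wg), builder.matmul(h, Rg));

  const z = builder.sigmoid(linear(W.z, R.z));  // update gate
  const r = builder.sigmoid(linear(W.r, R.r));  // reset gate
  const n = builder.tanh(
    builder.add(builder.matmul(x, W.n),
                builder.mul(r, builder.matmul(h, R.n))));  // candidate state

  // hNext = (1 - z) * n + z * h
  return builder.add(builder.mul(builder.sub(one, z), n), builder.mul(z, h));
}
```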

While there is some overlap - DML and CoreML are the most similar - there are still far more differences than similarities. For a web developer looking to deploy a model on WebNN across platforms, the tables suggest that aside from a few exceptional cases, these operators would need to be decomposed by the user agent.

If a high-level operator:

  • will usually need to be decomposed by the user agent anyway, and
  • could just as well be decomposed by the framework sitting on top of WebNN,

...should it still be in WebNN? :)

Options:

  1. Remove these operators from WebNN
  2. Keep these operators, understanding that they will often be decomposed
  3. Water down these operators into a least-common-denominator variant
    • Open question: Would real models use these watered-down variants?