webmachinelearning / webnn

🧠 Web Neural Network API
https://www.w3.org/TR/webnn/

Consider adopting new broadcasting rules #590

Open a-sully opened 4 months ago

a-sully commented 4 months ago

Reviving a discussion from #534, which defined shape broadcasting but didn't touch on the question of what WebNN's shape broadcasting rules should be

WebNN currently specifies two kinds of broadcasting rules: unidirectional and bidirectional

Of the popular ML frameworks, ONNX (which WebNN is largely based on) appears to be an outlier in making a distinction between "unidirectional" and "multidirectional" broadcasting. This distinction is not made by:

The "unidirectional broadcastable" constraint of some ONNX ops (e.g. prelu()) requires workarounds when exporting from other formats to ONNX - like in this example of using TVM to export PyTorch to ONNX: https://github.com/pytorch/pytorch/issues/70570#issuecomment-1034379620.

What should we do?

Option 1: Adopt NumPy's broadcasting rules

Rationale: NumPy's broadcasting rules are a de facto standard across the industry. It seems reasonable for them to be what we expose to the web

Outcome: "bidirectional broadcasting" will be the only type of broadcasting exposed to the web. The user agent muse ensure that the constraints of the underlying framework - such as unidirectional broadcasting for ONNX (@fdwr has suggested that this is trivial), and lack of inferred broadcasting specifications for XLA (more on that below) - are satisfied.

Option 2: Adopt XLA's broadcasting rules

Rationale: The XLA Principles apply to WebNN, too:

The XLA language is as strict and explicit as possible, avoiding implicit "magical" features. Such features might make some computations slightly easier to define, but at the cost of more assumptions baked into user code that will be difficult to change in the long term. If necessary, implicit magical features can be added in client-level wrappers... With regard to broadcasting, XLA requires explicit broadcasting specifications on operations between arrays of different ranks. This is different from NumPy, which infers the specification when possible.

Outcome: Both "unidirectional broadcasting" and "bidirectional broadcasting" concepts would be removed from the WebNN spec. To facilitate explicit broadcasts, something like StableHLO's broadcast_in_dim op would need to be added to WebNN
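For a sense of what such an explicit broadcast looks like in practice, JAX exposes the StableHLO-style op as jax.lax.broadcast_in_dim, where the caller spells out the output shape and the dimension mapping (any WebNN naming would of course be up for discussion):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])  # shape (3,)

# Explicitly broadcast (3,) -> (2, 3): input dim 0 maps to output dim 1.
# Nothing is inferred; the caller states the output shape and dim mapping.
y = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
assert y.shape == (2, 3)
```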

Option 3: Keep the status quo

Rationale: It's the status quo

Outcome: No action needed regarding the current spec. However, all models ported to WebNN will need to abide by this "unidirectionally broadcastable" constraint which is specific to ONNX

fdwr commented 4 months ago

Yo Austin - if the spec is unclear, then yeah, it should be made clearer. Before sharing my thoughts, let's first break down broadcasting into its three parts:

For XLA, step 1 does not happen because it expects the ranks to already match. Step 2 uses bidirectional broadcasting for the elementwise operators, and XLA's BroadcastInDim uses unidirectional broadcasting of the input shape to the expected output shape, even if they don't call it by that name.

For NumPy, all 3 steps happen: step 1 is right-aligned (though there are likely cases of middle-aligned broadcasts too given axes, at least internally for processing), and step 2 is bidirectional, except in the case of its own broadcasting operator broadcast_to, which uses unidirectional broadcasting even if they don't call it by that name (e.g. numpy.broadcast_to([1, 2, 3], (3, 3)) works, while numpy.broadcast_to([1, 2, 3], (3, 1)) fails because the input shape is not unidirectionally broadcastable to the output shape).
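(Quick NumPy demonstration of those two calls:)

```python
import numpy as np

# Works: input shape (3,) is unidirectionally broadcastable to (3, 3).
print(np.broadcast_to([1, 2, 3], (3, 3)).shape)  # (3, 3)

# Fails with ValueError: the trailing dimension 3 cannot become 1, so (3,)
# is not unidirectionally broadcastable to (3, 1).
try:
    np.broadcast_to([1, 2, 3], (3, 1))
except ValueError as err:
    print(err)
```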

And my thinking so far:

So, I'm partial to an option 4:

Option 4

@huningxin?

a-sully commented 3 months ago

Thank you @fdwr for the very thorough response! I think your Option 4 proposal makes sense, with one addendum

My primary motivations for filing this issue were:

It seems that I've been successful in that latter point :)

here's another case of balancing backend complexity vs front-end caller complexity ⚖

In the spirit of https://github.com/extensibleweb/manifesto I'm generally in favor of pushing complexity to callers (e.g. "This leads to better performance with less implementation effort"). In this case, I didn't expect that we'd actually adopt XLA's broadcasting rules for WebNN, though I figured it was worth calling it out as the option furthest towards the "caller complexity" end of things :P


As for the follow-up question... Regarding:

any WebNN operators that use broadcasting should be clear that they do, rather than something that implicitly happens for any operator

I agree! Related to that:

you can find rare occurrences of left alignment and even middle axis alignment, such as with instanceNormalization's scale and bias in the decomposition algorithm (e.g. [7] -> [1,7,1,1]) and then:

Is this middle axis alignment perhaps only relevant when using NCHW layout? If we were using NHWC layout, would [7] broadcast to [1, 1, 1, 7]?
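(A rough NumPy illustration of what I mean, with made-up shapes:)

```python
import numpy as np

x_nchw = np.zeros((2, 7, 4, 4))  # hypothetical NCHW input
scale = np.arange(7.0)           # per-channel scale, shape (7,)

# NCHW: the channel axis is in the middle, so [7] must be explicitly
# reshaped to [1, 7, 1, 1] before right-aligned broadcasting lines up.
out_nchw = x_nchw * scale.reshape(1, 7, 1, 1)

# NHWC: the channel axis is last, so plain right-aligned broadcasting
# already treats [7] as [1, 1, 1, 7].
x_nhwc = np.zeros((2, 4, 4, 7))
out_nhwc = x_nhwc * scale
```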

Regardless, the spec of instanceNormalization doesn't say anything about broadcasting. Let's add a fourth action item to Option 4?