webmachinelearning / webnn

🧠 Web Neural Network API
https://www.w3.org/TR/webnn/

gather(): Address indices validation and other algorithm nits #642

Closed inexorabletash closed 7 months ago

inexorabletash commented 7 months ago

Fixes #486 Fixes #484



inexorabletash commented 7 months ago

I noticed the other glitches in gather(), then went looking for open bugs and tackled #486 at the same time. Wheeee!

inexorabletash commented 7 months ago

@a-sully can you take a first look?

inexorabletash commented 7 months ago

Do we want to tweak this at all to also tackle #484 or leave that for another PR?

a-sully commented 7 months ago

Ah good question...

To specify support for negative indices in WebNN, we need to assume it's reasonable that every WebNN backend (that doesn't already support negative indices) will be able to transform these values at runtime. For example, the simplest way for our WebNN implementation to polyfill is by inserting if and add operations:

if index >= 0:
  return index
else:
  return index + input.dimensions[axis]

This assumes that an if operator is available. Looking at #484, TFLite and StableHLO are called out as the backends which don't support negative indices. Both of those backends appear to have an if operator, though I'm not sure whether e.g. if is supported on all TFLite backends we care about.

This becomes easier if we say that WebNN must support control flow ops #559

There may be some other more complicated way to transform these values, too...
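
For illustration only, the per-index transformation described above (leave non-negative indices alone, add the dimension size to negative ones) can be sketched in plain TypeScript. The names `normalizeIndex`, `normalizeIndices`, and `dimSize` are hypothetical and not part of WebNN; a real backend would express this with graph operators rather than host-side code.

```typescript
// Hypothetical helper, not part of the WebNN API: maps a possibly negative
// index onto the range [0, dimSize), assuming -dimSize <= index < dimSize.
function normalizeIndex(index: number, dimSize: number): number {
  return index >= 0 ? index : index + dimSize;
}

// Applies the same rule to every element of a flat indices array.
function normalizeIndices(indices: number[], dimSize: number): number[] {
  return indices.map((index) => normalizeIndex(index, dimSize));
}

// With a dimension of size 3, -1 refers to the last element (2)
// and -3 refers to the first element (0).
console.log(normalizeIndices([0, 2, -1, -3], 3)); // [0, 2, 2, 0]
```

Note that index 0 must take the "unchanged" branch; otherwise it would be mapped to `dimSize`, which is out of range.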

@huningxin any thoughts?

huningxin commented 7 months ago

@a-sully

This becomes easier if we say that WebNN must support control flow ops #559

There may be some other more complicated way to transform these values, too...

It might be emulated by where:

add(indices, where(lesser(indices, constant(0)), constant(input.dimensions[axis]), constant(0)))

a-sully commented 7 months ago

Thanks, @huningxin!

To specify support for negative indices in WebNN, we need to assume it's reasonable that every WebNN backend (that doesn't already support negative indices) will be able to transform these values at runtime

Seems reasonable. @inexorabletash shall we fold that into this PR (a nonnormative note to implementers, perhaps?) and close #484?

inexorabletash commented 7 months ago

Seems reasonable. @inexorabletash shall we fold that into this PR (a nonnormative note to implementers, perhaps?) and close #484?

I added a sentence to the note in 446bc773590c1e8814e7b593c6913da7fd3f4b09 - I kept it short, and didn't provide the decomposition. WDYT?

a-sully commented 7 months ago

LGTM, especially since there are multiple reasonable decompositions

inexorabletash commented 7 months ago

@fdwr can you do a final review and merge if it looks good to you? Thanks!

fdwr commented 7 months ago

It might be emulated by where:

add(indices, where(lesser(indices, constant(0)), constant(input.dimensions[axis]), constant(0)))

Yep, that looks right. An alternative would be...

where(lesser(indices, constant(0)), add(indices, constant(input.dimensions[axis])), indices)

...which is slightly shorter (one less constant(0)), but it's a tossup which one is more efficient 🤔.
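
As a rough check that the two where-based decompositions agree, here is a minimal TypeScript sketch that models the element-wise operators on plain number arrays. The functions `constant`, `add`, `lesser`, and `where` below are simplified stand-ins written for this example (no broadcasting, flat arrays only), not the actual MLGraphBuilder methods, and `viaOffset` / `viaSelect` are hypothetical names. The select form assumes where(condition, trueValue, falseValue) ordering, so the adjusted index sits in the true branch.

```typescript
// Plain-array stand-ins for element-wise ops (illustrative only, no broadcasting).
const constant = (value: number, length: number): number[] =>
  new Array<number>(length).fill(value);
const add = (a: number[], b: number[]): number[] => a.map((x, i) => x + b[i]);
const lesser = (a: number[], b: number[]): boolean[] => a.map((x, i) => x < b[i]);
const where = (cond: boolean[], t: number[], f: number[]): number[] =>
  cond.map((c, i) => (c ? t[i] : f[i]));

// First form: add an offset that is dimSize for negative indices and 0 otherwise.
function viaOffset(indices: number[], dimSize: number): number[] {
  const n = indices.length;
  const isNegative = lesser(indices, constant(0, n));
  return add(indices, where(isNegative, constant(dimSize, n), constant(0, n)));
}

// Second form: select between the adjusted index and the original index.
function viaSelect(indices: number[], dimSize: number): number[] {
  const n = indices.length;
  const isNegative = lesser(indices, constant(0, n));
  return where(isNegative, add(indices, constant(dimSize, n)), indices);
}

const sample = [0, 1, -1, -4];
console.log(viaOffset(sample, 4)); // [0, 1, 3, 0]
console.log(viaSelect(sample, 4)); // [0, 1, 3, 0]
```

Both forms leave non-negative indices untouched and differ only in whether the addition happens before or after the select.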

This assumes that an if operator is available

Btw, the former name of where was elementwiseIf, which makes it more immediately clear to those less familiar with ML that it's the common select or ternary operator. Though, the new name is more discoverable to people who already know TF and PT ⚖️.

inexorabletash commented 7 months ago

Thanks @fdwr - I incorporated your change then fixed one grammar glitch.