webmachinelearning / webnn

🧠 Web Neural Network API
https://www.w3.org/TR/webnn/

Simplify, correct, and add validation for GRU/LSTM and friends #659

Closed: inexorabletash closed this 1 month ago

inexorabletash commented 1 month ago

The fun never stops!

I may be wrong, and some of the missing validation may be intentional. For example, for lstmCell, cellState's data type and rank are validated, but not the actual dimensions called out in the prose ("The 2-D input cell state tensor of shape [batchSize, hiddenSize].").
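As a rough illustration of what the missing check amounts to, here is a minimal TypeScript sketch. It assumes a simplified operand record (the `OperandInfo` type and the `validateLstmCellState` helper are hypothetical, not part of the WebNN API or spec text), and, for illustration only, that cellState's data type must equal input's. The last two comparisons are the dimension checks the prose implies:

```typescript
// Hypothetical, simplified stand-in for an operand's descriptor; not the WebNN API.
interface OperandInfo {
  dataType: string;
  shape: number[];
}

// Hypothetical helper: returns true if cellState matches the prose
// "The 2-D input cell state tensor of shape [batchSize, hiddenSize]".
function validateLstmCellState(
  cellState: OperandInfo,
  input: OperandInfo,
  batchSize: number,
  hiddenSize: number,
): boolean {
  // Assumed rule for this sketch: data type must match the input's.
  if (cellState.dataType !== input.dataType) return false;
  // Rank check.
  if (cellState.shape.length !== 2) return false;
  // Dimension checks described in the prose.
  if (cellState.shape[0] !== batchSize) return false;
  if (cellState.shape[1] !== hiddenSize) return false;
  return true;
}
```

With only the data type and rank checks, a cellState of shape [4, 9] would pass for hiddenSize 8; the dimension checks reject it.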

inexorabletash commented 1 month ago

Thanks for the close look, @huningxin; all those weight/recurrentWeight/hiddenSize/inputSize names blur together after a while.

inexorabletash commented 1 month ago

I noticed one more case where this could be applied, instanceNormalization, and bundled it into this PR since it was on topic. Done in fcf0479b84d51af806113c00245aaf35d38922dc

fdwr commented 1 month ago

Rather than validating, for example, that rank = 2, shape[0] = N, and shape[1] = M, just compare shape against the list « N, M ».

This is nice and concise.
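A minimal TypeScript sketch of that suggestion, assuming shapes are plain number arrays (the `shapeEquals` and `isShape2D` helpers are hypothetical illustrations, not spec algorithms):

```typescript
// Hypothetical helper: element-wise equality of two shapes. Equal length
// implies equal rank, so no separate rank check is needed.
function shapeEquals(actual: readonly number[], expected: readonly number[]): boolean {
  return (
    actual.length === expected.length &&
    actual.every((dim, i) => dim === expected[i])
  );
}

// Instead of: shape.length === 2 && shape[0] === n && shape[1] === m
function isShape2D(shape: readonly number[], n: number, m: number): boolean {
  return shapeEquals(shape, [n, m]);
}
```

Comparing the whole list folds the rank check into the length comparison, so a shape of the wrong rank is rejected without an extra step.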