Closed inexorabletash closed 1 month ago
The fun never stops!
I may be wrong, and the missing validation may be intentional - e.g. for lstmCell, cellState's data type and rank are validated, but not the actual dimensions called out in the prose ("The 2-D input cell state tensor of shape [batchSize, hiddenSize].")
Thanks for the close look, @huningxin - all those weight/recurrentWeight/hiddenSize/inputSize blur together after a while.
I noticed one more case where this could be applied: instanceNormalization - bundled it into this PR since it was on topic. Done in fcf0479b84d51af806113c00245aaf35d38922dc
Rather than validating for example that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ».
This is nice and concise.
Some steps in gruCell() were comparing a rank vs. an expected dimension (e.g. "rank is not equal to 3 * hiddenSize"). Fix these!
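To illustrate the kind of mistake described above, here is a minimal Python sketch rather than the spec text itself. It assumes a 1-D bias operand whose expected shape is [3 * hiddenSize]; the function names are hypothetical and exist only for this example.

```python
def check_bias_buggy(shape, hidden_size):
    # Buggy: compares the rank (number of dimensions) with a dimension size.
    rank = len(shape)
    return rank == 3 * hidden_size


def check_bias_fixed(shape, hidden_size):
    # Fixed: compares the whole shape with the expected list of dimensions.
    # Assumption for this sketch: the bias is 1-D with shape [3 * hiddenSize].
    return list(shape) == [3 * hidden_size]


hidden_size = 4
valid_bias_shape = [12]  # 3 * hidden_size

# The buggy check rejects a valid bias because its rank (1) is not 12.
buggy_result = check_bias_buggy(valid_bias_shape, hidden_size)
fixed_result = check_bias_fixed(valid_bias_shape, hidden_size)
```

With these inputs the buggy check rejects a valid operand, while the fixed check accepts it.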
Rather than validating for example that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ». This also implicitly fixes places that were inspecting shape[x] without validating the rank first. Done for: batchNormalization(), conv2d(), convTranspose2d(), gru(), gruCell(), lstm(), lstmCell().
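As a rough illustration of why the whole-shape comparison is safer, the Python sketch below contrasts the two styles. The helper names are hypothetical, and plain lists stand in for the spec's shape lists.

```python
def validate_per_dimension(shape, n, m):
    # Inspects shape[0] and shape[1] without validating the rank first,
    # so a tensor with extra dimensions can slip through.
    return shape[0] == n and shape[1] == m


def validate_whole_shape(shape, n, m):
    # Comparing against the full expected list checks the rank and
    # every dimension in a single step.
    return list(shape) == [n, m]


rank3_shape = [2, 3, 5]

# The per-dimension check wrongly accepts a rank-3 shape.
accepted_by_per_dimension = validate_per_dimension(rank3_shape, 2, 3)
accepted_by_whole_shape = validate_whole_shape(rank3_shape, 2, 3)
```

In Python, the per-dimension version would also raise an IndexError for a shape shorter than two elements, which is the code analogue of reading shape[x] before checking the rank.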
Some places did validate data type and rank, but only some or none of the dimensions. Make this consistent across the ops, at least to match the existing prose. Done for gru(), gruCell(), instanceNormalization(), lstm(), lstmCell().
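A small sketch of what "matching the existing prose" could look like for lstmCell's cellState, whose prose gives the shape [batchSize, hiddenSize]. The `Operand` class and `validate_cell_state` function are hypothetical stand-ins, and requiring the data type to equal the input's data type is an assumption made for this example.

```python
from dataclasses import dataclass


@dataclass
class Operand:
    # Hypothetical stand-in for an operand descriptor.
    data_type: str
    shape: list


def validate_cell_state(cell_state, input_data_type, batch_size, hidden_size):
    # Checks the data type and the full shape, not only the rank.
    return (cell_state.data_type == input_data_type
            and cell_state.shape == [batch_size, hidden_size])


good = Operand("float32", [2, 8])
wrong_dims = Operand("float32", [2, 9])  # rank is 2, but hiddenSize is wrong

good_ok = validate_cell_state(good, "float32", 2, 8)
wrong_dims_ok = validate_cell_state(wrong_dims, "float32", 2, 8)
```

A rank-only check would accept both operands; the full shape comparison rejects the second one.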