An issue was reported where onnx-dml would give wildly bad evaluation in fp16 (default) mode for the following position: rnbq1rk1/p1p3pp/4pp2/1p1n4/P1pP2N1/6P1/1PQbPPBP/R1B2RK1 b - - 1 12
This was tracked down to the LayerNormalization operator right after the ReLU^2 activation function: the input data had a variance of 246419, which is outside the fp16 range (fp16 maxes out around 65504), even though the onnx docs say the normalization is done in fp32 by default.
To fix this, we scale down the input of the LayerNormalization operator by a sufficient amount.
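As an illustration of the failure mode, here is a minimal numpy sketch (made-up data and scale factor, not lc0 or onnxruntime code) of why a variance of this size breaks an fp16 reduction, and why pre-scaling the LayerNormalization input is harmless:

```python
import numpy as np

rng = np.random.default_rng(0)

# Fake LayerNormalization input with roughly the reported variance (~246419).
x = rng.normal(scale=np.sqrt(246419.0), size=4096).astype(np.float16)

# If the variance reduction is carried out in fp16, the squared deviations
# exceed the fp16 maximum (~65504) and overflow to inf.
sq = np.square(x)
print(np.isinf(sq).any(), sq.mean())   # True inf

# Dividing the input by a power of two first keeps everything in range...
k = np.float16(16.0)
xs = x / k
print(np.square(xs).mean())            # finite, roughly 246419 / k**2

# ...and LayerNormalization is invariant to a positive rescaling of its input,
# so the normalized output (computed here in fp32 for reference) is unchanged.
def layernorm(v):
    v = v.astype(np.float32)
    return (v - v.mean()) / np.sqrt(v.var() + 1e-6)

print(np.allclose(layernorm(x), layernorm(xs), atol=1e-3))   # True
```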
Some more details:
The operations after the previous normalization stage implement the following:
w2*ReLU(w1*(γ1*N+β1)+b1)^2+b2+α*(γ1*N+β1)
with N the normalized input. The first commit changes this to
w2/α*ReLU(w1*(γ1*N+β1)+b1)^2+b2/α+(γ1*N+β1)
which helps a bit (dividing the whole expression by α is safe since α > 0: the LayerNormalization output is invariant to a positive rescaling of its input) and is also a small speed-up, as w2/α and b2/α can be constant-folded and the explicit multiplication by α disappears.
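As a sanity check, here is a small numpy sketch (made-up shapes and weights, not the converter code) verifying that the folded form is just the original expression divided by α, so the LayerNormalization that follows produces the same output:

```python
import numpy as np

rng = np.random.default_rng(1)

def relu(v):
    return np.maximum(v, 0.0)

def layernorm(v, eps=1e-6):
    return (v - v.mean()) / np.sqrt(v.var() + eps)

C = 64
N = rng.standard_normal(C)                                    # already-normalized input
g1, beta1 = rng.standard_normal(C), rng.standard_normal(C)    # γ1, β1
w1, b1 = rng.standard_normal((C, C)), rng.standard_normal(C)
w2, b2 = rng.standard_normal((C, C)), rng.standard_normal(C)
alpha = 2.5                                                   # hypothetical α > 0

L = g1 * N + beta1
original = w2 @ relu(w1 @ L + b1) ** 2 + b2 + alpha * L
folded = (w2 / alpha) @ relu(w1 @ L + b1) ** 2 + b2 / alpha + L

# The folded expression equals original / α, and LayerNormalization is
# invariant to that positive rescaling, so the normalized outputs match.
assert np.allclose(folded, original / alpha)
assert np.allclose(layernorm(folded), layernorm(original))
```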
To fully fix the issue some more scaling is needed: the LayerNormalization input is divided by a factor k, folded into the constants as follows:
k*w2/α*ReLU(w1*(γ1/k*N+β1/k)+b1/k)^2+b2/α/k+(γ1/k*N+β1/k)
The value of k was chosen conservatively as the next power of 2 after max(|γ1|)*max(|w1|). The scaling is applied only for fp16, only when max(|γ1|)*max(|w1|) > 1, and only if the activation function is ReLU(.)^2.
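A small numpy sketch of this choice of k and of the resulting equivalence (illustrative only, with made-up shapes and weights, not the converter implementation):

```python
import numpy as np

rng = np.random.default_rng(1)

def relu(v):
    return np.maximum(v, 0.0)

def layernorm(v, eps=1e-6):
    return (v - v.mean()) / np.sqrt(v.var() + eps)

C = 64
N = rng.standard_normal(C)
g1, beta1 = rng.standard_normal(C), rng.standard_normal(C)
w1, b1 = rng.standard_normal((C, C)), rng.standard_normal(C)
w2, b2 = rng.standard_normal((C, C)), rng.standard_normal(C)
alpha = 2.5

# Conservative scale: next power of 2 after max(|γ1|)*max(|w1|),
# only applied when that product exceeds 1.
prod = np.abs(g1).max() * np.abs(w1).max()
k = 2.0 ** np.ceil(np.log2(prod)) if prod > 1.0 else 1.0

L = g1 * N + beta1
folded = (w2 / alpha) @ relu(w1 @ L + b1) ** 2 + b2 / alpha + L            # after the first commit
Lk = (g1 / k) * N + beta1 / k
scaled = (k * w2 / alpha) @ relu(w1 @ Lk + b1 / k) ** 2 + b2 / (alpha * k) + Lk

# The k-scaled expression is exactly folded / k, so the LayerNormalization
# output is unchanged while its input (and hence its variance) shrinks.
assert np.allclose(scaled, folded / k)
assert np.allclose(layernorm(scaled), layernorm(folded))
```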