Closed chenwhql closed 6 years ago
在阅读源码的时候,对PaddlePaddle Fluid,norm_op.h里面的梯度求解公式不太理解,代码如下:
L2 Normalization Code:
// y = x / sqrt((sum(x * x) + epsilon)) // norm = sqrt(sum(x * x) + epsilon) auto sum = x.pow(2).sum(rdim) + eps; norm.device(*place) = sum.sqrt(); // y = x / norm Eigen::DSizes<int, 3> rshape(pre, 1, post); Eigen::DSizes<int, 3> bcast(1, n, 1); y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
计算公式是:
L2 Normalization Gradient Code:
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)] // = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x)) // = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x)) // 1. sum = sum(x*dy) sum.device(*place) = (x * dy).sum(rdim); // 2. dx = x * sum dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x; // 3. dx / (sum(x*x) + e) // where, norm.pow(2) = sum(x*x) + e, which is calculated in forward. dx.device(*place) = dx / norm.pow(2).broadcast(bcast); // 4. [dy - dx] / sqrt(sum(x*x)) dx.device(*place) = (dy - dx) / norm.broadcast(bcast);
但是我自己求解的结果好像与代码有一点出入,过程如下:
代码里最后dy项为什么是乘以y'的?数学不太好,这里不太理解。
@chenwhql 您好:
BP的反向求导,这里要计算的是其实是 dJ/dx, J是损失函数。
依据链式法则:
dJ/dx = (dJ/dy ) * (dy/dx )
这里dJ/dy就是 y', dy/dx 就是推到的公式。
理解了,谢谢
在阅读源码的时候,对PaddlePaddle Fluid,norm_op.h里面的梯度求解公式不太理解,代码如下:
L2 Normalization Code:
计算公式是:
L2 Normalization Gradient Code:
计算公式是:
但是我自己求解的结果好像与代码有一点出入,过程如下:
代码里最后dy项为什么是乘以y'的?数学不太好,这里不太理解。