kim-marcel / basic_neural_network

A very basic Java Neural Network Library.
MIT License

Wrong hidden error computation #7

Open Orbiter opened 3 years ago

Orbiter commented 3 years ago

In these lines (https://github.com/kim-marcel/basic_neural_network/blob/master/src/main/java/basicneuralnetwork/NeuralNetwork.java#L165-L169):

    weights[n - 1] = weights[n - 1].plus(deltas);

    // Calculate and set target for previous (next) layer
    SimpleMatrix previousError = weights[n - 1].transpose().mult(errors);
    target = previousError.plus(layers[n - 1]);

you compute the target for the next (lower) layer from the matrix previousError. That product uses weights[n - 1], which at that point has already been updated with the correction (deltas). What is required here is the unaltered weight matrix: backpropagation into the lower layer should depend on the original error propagated through the original weights, not through the updated ones.
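To make the difference concrete, here is a tiny numeric sketch using plain Java arrays instead of EJML's SimpleMatrix. The weight, delta and error values are made up for illustration, and the helper names `transposeMult` and `plus` are hypothetical; only the order of operations mirrors the snippet above.

```java
import java.util.Arrays;

class UpdatedWeightsDemo {

    // Returns m^T * v, i.e. the error pushed back through weight matrix m
    // (stand-in for weights.transpose().mult(errors)).
    static double[] transposeMult(double[][] m, double[] v) {
        double[] out = new double[m[0].length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                out[j] += m[i][j] * v[i];
            }
        }
        return out;
    }

    // Returns a + b element-wise (stand-in for weights.plus(deltas)).
    static double[][] plus(double[][] a, double[][] b) {
        double[][] out = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                out[i][j] = a[i][j] + b[i][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical values: 1 output neuron, 2 neurons in the lower layer.
        double[][] weights = {{0.5, -0.25}};
        double[][] deltas = {{0.25, 0.5}};
        double[] errors = {2.0};

        double[] fromOriginal = transposeMult(weights, errors);
        double[] fromUpdated = transposeMult(plus(weights, deltas), errors);

        System.out.println(Arrays.toString(fromOriginal)); // [1.0, -0.5]
        System.out.println(Arrays.toString(fromUpdated));  // [1.5, 0.5]
    }
}
```

With these numbers the second lower-layer neuron even gets an error of the opposite sign, so the lower layer would be pushed in the wrong direction.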

As far as I can tell, this can be fixed by moving the previousError line above the weight update:

    // Propagate the error through the weights before they are updated
    SimpleMatrix previousError = weights[n - 1].transpose().mult(errors);
    weights[n - 1] = weights[n - 1].plus(deltas);
    target = previousError.plus(layers[n - 1]);
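For context, here is a minimal self-contained sketch of a whole backward pass with the corrected order. It does not reproduce the library's code: plain arrays instead of SimpleMatrix, a sigmoid activation, no biases, the learning-rate handling and all names (`BackwardPassSketch`, `backward`) are assumptions. It is only meant to show the loop structure from the snippet, with previousError taken from weights[n - 1] before that matrix is updated.

```java
import java.util.Arrays;

class BackwardPassSketch {

    // Derivative of the sigmoid expressed via its output y (assumed activation).
    static double sigmoidDerivative(double y) {
        return y * (1.0 - y);
    }

    // layers[k]: activated outputs of layer k from the forward pass (layers[0] = input).
    // weights[k]: maps layer k to layer k + 1, shape [size(k + 1)][size(k)]; updated in place.
    // Returns the error vector of each layer n = 1..L-1 (index 0 stays null).
    static double[][] backward(double[][][] weights, double[][] layers, double[] target, double lr) {
        double[][] errorsPerLayer = new double[layers.length][];
        double[] t = target.clone();
        for (int n = layers.length - 1; n > 0; n--) {
            double[] out = layers[n];
            double[] in = layers[n - 1];
            double[][] w = weights[n - 1];

            // errors = target - output of this layer
            double[] errors = new double[out.length];
            for (int i = 0; i < out.length; i++) errors[i] = t[i] - out[i];
            errorsPerLayer[n] = errors;

            // The fix: previousError = w^T * errors using the ORIGINAL weights.
            double[] previousError = new double[in.length];
            for (int i = 0; i < out.length; i++)
                for (int j = 0; j < in.length; j++)
                    previousError[j] += w[i][j] * errors[i];

            // Only now apply deltas = lr * (errors .* f'(out)) * in^T.
            for (int i = 0; i < out.length; i++) {
                double gradient = lr * errors[i] * sigmoidDerivative(out[i]);
                for (int j = 0; j < in.length; j++) w[i][j] += gradient * in[j];
            }

            // Target for the next (lower) layer, so that its errors equal previousError.
            t = new double[in.length];
            for (int j = 0; j < in.length; j++) t[j] = previousError[j] + in[j];
        }
        return errorsPerLayer;
    }

    public static void main(String[] args) {
        // Hypothetical 2-2-1 network; the activations are stand-ins, not a real forward pass.
        double[][][] weights = { {{0.5, -0.25}, {0.25, 0.5}}, {{1.0, -0.5}} };
        double[][] layers = { {1.0, 0.0}, {0.6, 0.4}, {0.7} };
        double[][] errors = backward(weights, layers, new double[]{1.0}, 0.1);
        System.out.println("hidden errors: " + Arrays.toString(errors[1]));
        System.out.println("updated output weights: " + Arrays.toString(weights[1][0]));
    }
}
```

Because previousError is computed first, the hidden errors here equal the original output weights times the output error ([0.3, -0.15] up to rounding for target 1.0), regardless of the learning rate.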