FlorianCassayre / DeepLearningUtil

A toy library written in Java to better understand Artificial Neural Networks.

Problem #1

hugo4715 closed this issue 7 years ago

hugo4715 commented 7 years ago

Hi, I'm trying out this lib using a simple XOR problem as the training set, but I cannot figure out what is wrong with my code.

```java
import java.util.ArrayList;
import java.util.List;
// ...plus imports for the DeepLearningUtil classes used below (not shown in the original post).

public class XorTest { // illustrative class name

    public static void main(String[] args) {
        System.out.println("XOR starting");
        // 2 inputs -> 2 sigmoid hidden units -> 1 sigmoid output, mean squared error loss
        final Network network = new Network.Builder(new Dimensions(2))
                .hookFullyConnected(new Dimensions(2), ActivationFunctionType.SIGMOID)
                .hookFullyConnected(new Dimensions(1), ActivationFunctionType.SIGMOID)
                .build(OutputFunctionType.MEAN_SQUARES);

        // Stochastic trainer; 0.1 is the learning rate
        Trainer t = new StochasticTrainer(network, 0.1);

        // The four XOR patterns and their expected outputs
        List<Volume> inputs = new ArrayList<>();
        List<Volume> outputs = new ArrayList<>();

        inputs.add(create2D(0, 0));
        outputs.add(create1D(0));

        inputs.add(create2D(1, 1));
        outputs.add(create1D(0));

        inputs.add(create2D(1, 0));
        outputs.add(create1D(1));

        inputs.add(create2D(0, 1));
        outputs.add(create1D(1));

        // Cycle through the four patterns, 10000 training steps in total
        for (int i = 0; i < 10000; i++) {
            t.train(inputs.get(i % 4), outputs.get(i % 4));
        }

        // Print the trained network's output for each pattern
        network.forwardPropagation(create2D(0, 0));
        System.out.println("0;0=" + network.getOutput().get(0));

        network.forwardPropagation(create2D(1, 1));
        System.out.println("1;1=" + network.getOutput().get(0));

        network.forwardPropagation(create2D(0, 1));
        System.out.println("0;1=" + network.getOutput().get(0));

        network.forwardPropagation(create2D(1, 0));
        System.out.println("1;0=" + network.getOutput().get(0));
    }

    // Builds a 2-element input volume
    private static Volume create2D(int value1, int value2) {
        Volume v = new Volume(new Dimensions(2));
        v.set(0, value1);
        v.set(1, value2);
        return v;
    }

    // Builds a 1-element target volume
    private static Volume create1D(int value1) {
        Volume v = new Volume(new Dimensions(1));
        v.set(0, value1);
        return v;
    }
}
```

Thanks

hugo4715 commented 7 years ago

Hmm, I have no clue.

FlorianCassayre commented 7 years ago

First, I appreciate your interest in the repository. Note that the library is missing many assertions, and every feature has only been tested empirically, so it probably contains bugs or may not work as expected.

Your code is correct, and it is actually working as intended: the network got stuck in a local optimum. XOR is not an easy function to learn, because it is not linearly separable, so the hidden layer has to discover a suitable intermediate representation. You gave your network only two hidden units, and although a solution exists for this configuration, training is unlikely to find it. In the example I provided, I set the default number of hidden units to 10, and it gave me much better results. I also recommend using the Adadelta trainer: you don't have to set a learning rate, since it adapts the step size automatically during training, and it converges very quickly.
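To experiment with the effect of the hidden layer size and of the random initialization outside the library, here is a library-independent sketch of the same kind of network: 2 inputs, `hidden` sigmoid units, 1 sigmoid output, trained by plain stochastic gradient descent on the squared error, cycling through the four patterns as in the code above. It does not use DeepLearningUtil; the class name, the learning rate of 0.5 and the iteration count are arbitrary assumptions, and how many seeds end up stuck will depend on those choices.

```java
import java.util.Random;

/** Hypothetical standalone XOR experiment, unrelated to DeepLearningUtil's API. */
public final class XorSketch {
    public static final double[][] INPUTS = {{0, 0}, {1, 1}, {1, 0}, {0, 1}};
    public static final double[] TARGETS = {0, 0, 1, 1};

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Trains a 2-hidden-1 network and returns its outputs on INPUTS. */
    public static double[] train(int hidden, double learningRate, int iterations, long seed) {
        Random rnd = new Random(seed);
        double[][] w1 = new double[hidden][2]; // input -> hidden weights
        double[] b1 = new double[hidden];      // hidden biases (start at 0)
        double[] w2 = new double[hidden];      // hidden -> output weights
        double b2 = 0.0;                       // output bias
        for (int j = 0; j < hidden; j++) {
            w1[j][0] = rnd.nextGaussian();
            w1[j][1] = rnd.nextGaussian();
            w2[j] = rnd.nextGaussian();
        }
        double[] h = new double[hidden];
        for (int it = 0; it < iterations; it++) {
            double[] x = INPUTS[it % 4]; // same cyclic order as in the issue
            // Forward pass
            double z2 = b2;
            for (int j = 0; j < hidden; j++) {
                h[j] = sigmoid(w1[j][0] * x[0] + w1[j][1] * x[1] + b1[j]);
                z2 += w2[j] * h[j];
            }
            double y = sigmoid(z2);
            // Backward pass for E = (y - target)^2 / 2
            double dOut = (y - TARGETS[it % 4]) * y * (1 - y);
            for (int j = 0; j < hidden; j++) {
                double dHid = dOut * w2[j] * h[j] * (1 - h[j]); // uses w2[j] before its update
                w2[j] -= learningRate * dOut * h[j];
                w1[j][0] -= learningRate * dHid * x[0];
                w1[j][1] -= learningRate * dHid * x[1];
                b1[j] -= learningRate * dHid;
            }
            b2 -= learningRate * dOut;
        }
        // Evaluate the trained network on the four patterns
        double[] out = new double[4];
        for (int p = 0; p < 4; p++) {
            double z2 = b2;
            for (int j = 0; j < hidden; j++) {
                z2 += w2[j] * sigmoid(w1[j][0] * INPUTS[p][0] + w1[j][1] * INPUTS[p][1] + b1[j]);
            }
            out[p] = sigmoid(z2);
        }
        return out;
    }

    public static void main(String[] args) {
        // Compare several random initializations for 2 and 10 hidden units
        for (int hidden : new int[] {2, 10}) {
            for (long seed = 0; seed < 5; seed++) {
                double[] out = train(hidden, 0.5, 100_000, seed);
                double maxErr = 0;
                for (int p = 0; p < 4; p++) {
                    maxErr = Math.max(maxErr, Math.abs(out[p] - TARGETS[p]));
                }
                System.out.printf("hidden=%d seed=%d max error=%.3f%n", hidden, seed, maxErr);
            }
        }
    }
}
```

A run whose maximum error stays near 0.5 has typically settled in a poor solution; comparing the rows for `hidden=2` and `hidden=10` is one way to check the advice about hidden units empirically.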

Edit: Forgot to amend
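The Adadelta trainer recommended above does not need a learning rate because, for each parameter, it keeps running averages of squared gradients and of squared updates, and scales each step by the ratio of their square roots. The sketch below applies that rule (as described in Zeiler's ADADELTA paper) to minimize f(x) = (x - 3)^2; it is not DeepLearningUtil's implementation, and the class name and the values rho = 0.95 and epsilon = 1e-6 are assumptions.

```java
/** Hypothetical minimal Adadelta sketch, not DeepLearningUtil's trainer. */
public final class AdadeltaSketch {
    private final double rho;        // decay rate of both running averages
    private final double epsilon;    // small constant that avoids division by zero
    private final double[] accGrad;  // running average of squared gradients, E[g^2]
    private final double[] accDelta; // running average of squared updates, E[dx^2]

    public AdadeltaSketch(int size, double rho, double epsilon) {
        this.rho = rho;
        this.epsilon = epsilon;
        this.accGrad = new double[size];
        this.accDelta = new double[size];
    }

    /** Updates params in place from grads; no global learning rate is involved. */
    public void step(double[] params, double[] grads) {
        for (int i = 0; i < params.length; i++) {
            double g = grads[i];
            accGrad[i] = rho * accGrad[i] + (1 - rho) * g * g;
            // Step size = RMS of past updates / RMS of gradients
            double delta = -Math.sqrt(accDelta[i] + epsilon) / Math.sqrt(accGrad[i] + epsilon) * g;
            accDelta[i] = rho * accDelta[i] + (1 - rho) * delta * delta;
            params[i] += delta;
        }
    }

    public static void main(String[] args) {
        // Minimize f(x) = (x - 3)^2 from x = 0; the gradient is 2(x - 3)
        AdadeltaSketch opt = new AdadeltaSketch(1, 0.95, 1e-6);
        double[] x = {0.0};
        for (int t = 0; t < 5000; t++) {
            opt.step(x, new double[] {2 * (x[0] - 3)});
        }
        System.out.println("x after 5000 steps: " + x[0]);
    }
}
```

Because the accumulated update term starts at zero, the first steps are tiny (on the order of sqrt(epsilon)) and grow as updates accumulate, which is the usual trade-off of not having to tune a learning rate by hand.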