attractivechaos / kann

A lightweight C library for artificial neural networks
Other
678 stars 117 forks source link

xor example #14

Closed: mazko closed this issue 5 years ago

mazko commented 5 years ago

Hi, here is my code:

// gcc xor.c ../kann.c ../kautodiff.c -I. -I../ -lm && ./a.out

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include "kann.h"

/* Build a feed-forward network:
 *   input(n_in) -> [dense(n_h_neurons) -> ReLU] x n_h_layers -> cost(n_out, loss_type)
 * A non-positive n_h_layers yields a network with no hidden layers.
 * Returns the model created by kann_new(). */
static kann_t *model_gen(int n_in, int n_out, int loss_type, int n_h_layers, int n_h_neurons)
{
  kad_node_t *layer = kann_layer_input(n_in);
  int remaining;

  for (remaining = n_h_layers; remaining > 0; --remaining) {
    kad_node_t *dense = kann_layer_dense(layer, n_h_neurons);
    layer = kad_relu(dense);
  }
  return kann_new(kann_layer_cost(layer, n_out, loss_type), 0);
}

/* Train `ann` on the four-row XOR truth table with kann_train_fnn1(). */
static void train(kann_t *ann)
{
  enum { n_samples = 4 };

  /* Row-major truth table. kann_train_fnn1() takes one pointer per sample,
   * so x[] and y[] are filled with pointers to these rows below. */
  float inputs[n_samples][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
  float targets[n_samples][1] = { { 0 }, { 1 }, { 1 }, { 0 } };
  float *x[n_samples], *y[n_samples];

  /* NOTE(review): these names follow the kann_train_fnn1() prototype in
   * kann.h (lr, mini_size, max_epoch, max_drop_streak, frac_val); confirm
   * there. With only 4 samples, a 0.1 validation fraction may round down
   * to zero validation rows, depending on how kann.c converts it. */
  const float learning_rate = 0.001f;
  const float val_fraction = 0.1f;
  const int mini_batch = 64;
  const int max_epochs = 10000;
  const int max_drop_streak = 10;
  int s;

  for (s = 0; s < n_samples; ++s) {
    x[s] = inputs[s];
    y[s] = targets[s];
  }
  kann_train_fnn1(ann, learning_rate, mini_batch, max_epochs, max_drop_streak,
                  val_fraction, n_samples, x, y);
}

/* Run `ann` on each XOR input pair. For each pair, print
 * "<network output> | <expected value>" on its own line. */
void predict(kann_t *ann)
{
  float inputs[4][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
  float expected[4] = { 0.0f, 1.0f, 1.0f, 0.0f };
  int row;

  for (row = 0; row < 4; ++row) {
    const float *out = kann_apply1(ann, inputs[row]);
    printf("%f | %f\n", out[0], expected[row]);
  }
}

/* Build, train and evaluate a 2-input / 1-output XOR network.
 *
 * Usage: ./a.out [n_hidden_neurons]
 *   n_hidden_neurons  width of the single hidden layer (integer >= 1).
 *                     Defaults to 5, the original hard-coded value, so
 *                     running without arguments behaves as before. This
 *                     makes it easy to try the 2- or 3-neuron variants.
 *
 * Returns 0 on success and 1 if the argument is not a valid positive int. */
int main(int argc, char *argv[])
{
  enum { N_IN = 2, N_OUT = 1, N_HIDDEN_LAYERS = 1, DEFAULT_HIDDEN_NEURONS = 5 };
  int n_h_neurons = DEFAULT_HIDDEN_NEURONS;
  kann_t *ann;

  if (argc > 1) {
    char *end;
    long v;

    /* strtol rather than atoi so that garbage, trailing characters and
     * out-of-range values are rejected instead of silently becoming 0. */
    errno = 0;
    v = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || errno == ERANGE || v < 1 || v > INT_MAX) {
      fprintf(stderr, "usage: %s [n_hidden_neurons >= 1]\n", argv[0]);
      return 1;
    }
    n_h_neurons = (int)v;
  }

  ann = model_gen(N_IN, N_OUT, KANN_C_CEB, N_HIDDEN_LAYERS, n_h_neurons);
  train(ann);
  predict(ann);
  kann_delete(ann);

  return 0;
}

Program output:

0.000902 | 0.000000
0.999955 | 1.000000
0.999937 | 1.000000
0.000029 | 0.000000

As far as I know, XOR requires 3 neurons in the hidden layer, not 5. Here is a Keras example:

# Keras model for comparison: a hidden Dense layer of 3 ReLU units over
# 2 inputs, followed by a single sigmoid output unit.
# NOTE(review): the imports (Sequential, Dense) and the training_data /
# target_data arrays are not shown; presumably they hold the 4 XOR rows.
model = Sequential()
model.add(Dense(3, input_dim=2, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Mean-squared-error loss with the Adam optimizer, reporting binary accuracy.
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['binary_accuracy'])
model.fit(training_data, target_data, epochs=10000, verbose=2)
# Python 2 print statement (not valid Python 3 syntax).
print model.predict(training_data)
[[0.0073216]
 [0.9848797]
 [0.9848797]
 [0.0067511]]

Why 5 neurons?

mazko commented 5 years ago

Changed kad_relu to kad_exp and the number of hidden neurons from 5 to 2. Output:

0.001615 | 0.000000
0.999962 | 1.000000
0.999967 | 1.000000
0.000002 | 0.000000