Marfein / -12

0 stars 0 forks source link

Обучить однослойный перцептрон распознаванию арабских цифр #2

Open Marfein opened 1 year ago

Marfein commented 1 year ago

import java.util.Random;

public class Perceptron {

private double[] weights;
private double learningRate;

public Perceptron(int inputSize, int outputSize, double learningRate) {
    weights = new double[inputSize * outputSize];
    this.learningRate = learningRate;
    initializeWeights();
}

public void train(double[][] inputs, int[] labels, int epochs) {
    for (int epoch = 0; epoch < epochs; epoch++) {
        for (int i = 0; i < inputs.length; i++) {
            double[] input = inputs[i];
            int label = labels[i];
            double[] output = predict(input);
            adjustWeights(input, output, label);
        }
    }
}

public int predict(double[] input) {
    double[] output = predict(input);
    int maxIndex = 0;
    for (int i = 0; i < output.length; i++) {
        if (output[i] > output[maxIndex]) {
            maxIndex = i;
        }
    }
    return maxIndex;
}

private double[] predict(double[] input) {
    double[] output = new double[weights.length];
    for (int i = 0; i < weights.length; i++) {
        output[i] = input[i % input.length] * weights[i];
    }
    return output;
}

private void adjustWeights(double[] input, double[] output, int label) {
    for (int i = 0; i < output.length; i++) {
        double delta = (i == label ? 1 : 0) - output[i];
        for (int j = 0; j < input.length; j++) {
            weights[i * input.length + j] += learningRate * delta * input[j];
        }
    }
}

private void initializeWeights() {
    Random random = new Random();
    for (int i = 0; i < weights.length; i++) {
        weights[i] = random.nextDouble() * 2 - 1;
    }
}

}

Marfein commented 1 year ago

public class Main {

public static void main(String[] args) {
    // создаем перцептрон
    Perceptron perceptron = new Perceptron(100, 10, 0.1);

    // обучаем перцептрон на наборе данных
    double[][] inputs = {
            {
                1, 1, 1, 1, 0,  // цифра 0
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 0, 0, 1,
                1, 1, 1, 1
            },
            {
                0, 0, 1, 1, 0,  // цифра 9
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 1
            }
    };
    int[] labels = {0, 9};
    perceptron.train(inputs, labels, 100);

    // тестируем работу перцептрона на новых данных
    double[] input = {
            0, 1, 1, 1, 0,  // цифра 2
            0, 0, 0, 1,
            0, 0, 1, 0,
            0, 1, 0, 0,
            1, 0, 0, 0,
            1, 1, 1, 1,
            0, 0, 0, 0,
            0, 0, 0, 0,
            0, 0, 0, 0,
            0, 0, 0, 0
    };
    int predictedLabel = perceptron.predict(input);
    System.out.println("Predicted label: " + predictedLabel);  // должно быть 2
}

}