Open hmwildermuth opened 8 years ago
Here is also some code I wrote to run tests on the output network (it uses the same ocr.h):
#include "ocr.h"
/// Prints command-line help for this test driver to standard output.
/// @param firstarg Program name shown in the synopsis line (main passes argv[0]).
void printUsage(char const * firstarg) {
    std::ostream &out = std::cout;
    out << "Usage:" << std::endl
        << "\t" << firstarg << " [-v] [-c NUMBER] <filename>" << std::endl
        << "Possible arguments:" << std::endl
        << "\t-v: Verbose mode" << std::endl
        << "\t\tIf not specified, defaults to off" << std::endl
        << "\t-c: Count (number of tests run)" << std::endl
        << "\t\tIf not specified, defaults to all tests (10,000)" << std::endl;
}
/// Prints a summary of the test run to standard output.
/// @param right Number of tests whose predicted digit matched the label.
/// @param wrong Number of tests whose predicted digit did not match.
/// When no tests were run (right + wrong == 0), the percentage is reported
/// as "N/A" rather than computing 0.0/0.0, which would print "nan%".
void printScore(int right, int wrong) {
    const int total = right + wrong;
    std::cout << "Final Results:" << std::endl;
    std::cout << "\tCorrect: " << right << std::endl;
    std::cout << "\tWrong: " << wrong << std::endl;
    std::cout << "\tTests: " << total << std::endl;
    std::cout << "\tPercent Correct: ";
    if (total > 0) {
        std::cout << double(right) / double(total) * 100.0 << "%";
    } else {
        std::cout << "N/A";
    }
    std::cout << std::endl;
}
int main(int argc, char const *argv[]) {
std::string inputFilename = "ocr.txt";
bool verbose = false;
bool setFilename = false;
int checkSize = -100;
if (argc > 1) {
for (size_t i = 1; i < argc; i++) {
if (!strcmp(argv[i], "-v")) {
verbose = true;
} else if (!strcmp(argv[i], "-c")) {
i++;
checkSize = atoi(argv[i]);
} else {
setFilename = true;
inputFilename = argv[i];
}
}
}
if (!setFilename) {
printUsage(argv[0]);
return -1;
}
std::ifstream myfile;
myfile.open(inputFilename);
net::NeuralNet network = net::NeuralNet(&myfile);
myfile.close();
std::string lbels = "t10k-labels.idx1-ubyte";
std::string imges = "t10k-images.idx3-ubyte";
int mgicNum;
int sizeNum;
std::cout << (verbose ? "Loading images from files...\n" : "");
auto inputArr = read_mnist_images(imges, mgicNum, sizeNum);
auto outputArr = read_mnist_labels(lbels, mgicNum);
net::NeuralNet neuralNetwork = net::NeuralNet(sizeNum, 10, 1, sizeNum, "sigmoid");
std::vector< std::vector<double> > input;
std::vector< std::vector<double> > correctOutput;
std::cout << (verbose ? "Loading into vector...\n" : "");
for (size_t i = 0; i < mgicNum; i++) {
std::vector<double> imgeArr;
for (size_t j = 0; j < sizeNum; j++) {
imgeArr.push_back(double(inputArr[i][j])/double(255));
}
//std::cout << imgeArr.size() << "; " << sizeNum << "\n";
input.push_back(imgeArr);
correctOutput.push_back(digits(outputArr[i]));
}
std::cout << (verbose ? "Done with loading.\n" : "");
std::cout << (verbose ? "Freeing memory...\n" : "");
delete [] inputArr; // <- Is this how you use delete? idk
delete [] outputArr;
// free(inputArr);
// free(outputArr);
std::cout << (verbose ? "Done with freeing memory.\n\n" : "");
if (checkSize == -100) {
checkSize = correctOutput.size();
} else if (checkSize > correctOutput.size()) {
throw std::runtime_error("Number of tests wanted is larger than actual number of tests: " + std::to_string(checkSize));
} else if (checkSize < 1) {
throw std::runtime_error("Invalid number for count: " + std::to_string(checkSize));
}
int wrong = 0;
int right = 0;
for (size_t i = 0; i < checkSize; i++) {
int out = findTop(neuralNetwork.getOutput(input[i]));
int real = findTop(correctOutput[i]);
if ( out != real ) {
if (verbose) {
std::cout << "Failed test #" << i << ":"<< std::endl;
std::cout << "\tTest answer: " << out << std::endl;
std::cout << "\tCorrect answer: " << real << std::endl;
}
wrong++;
} else {
right++;
}
}
printScore(right, wrong);
return 0;
}
Did you solve it? I am having the same issue... My program keeps getting killed by the system.
I was trying to build an OCR neural network using the MNIST handwritten-digit database; however, every time I ran it, my process was killed by the kernel's OOM (out-of-memory) killer, which kills processes that use too much memory. I am not sure whether this is caused by my code or by something in the backpropagation code. Either way, any help would be appreciated.
Also note that when I run the program with the training sample cut down to only 250 images, it works, but above 500 images it fails.
The C++ File:
The header file, which contains functions for loading the test images and for finding the largest element of an array (I copied the MNIST-loading functions from elsewhere):