BrainJS / brain.js

🤖 GPU accelerated Neural networks in JavaScript for Browsers and Node.js
https://brain.js.org
MIT License

Add custom loss functions and an R/W state matrix #935

Open voidvoxel opened 2 weeks ago

voidvoxel commented 2 weeks ago


Summary

Feature branch

Wow. I'm mentally and physically drained after implementing this, but it was SOOOOOOOOOOOOO WORTH IT!!!!

I added support for defining custom loss functions, as well as an R/W state matrix that serves as RAM. Together, these allow training by ruleset in addition to, or even in place of, a pre-existing set of training data.

In my own benchmarks, a custom loss function designed to train for XOR improved training on the XOR data set by approximately 100% for `NeuralNetwork` (CPU) and 400% for `NeuralNetworkGPU` (GPU). That is, of course, a really specific example. However, it's just that: an example. This paves the way for any custom loss function to be defined.

On the GPU, you are limited to values you pass through the `NeuralNetwork.ram` property, which is the R/W state matrix mentioned above. On the CPU, however, you're hardly limited at all: the loss function is called from the CPU, not the GPU, so you could hypothetically even involve API calls in the loss calculation, as long as your internet bandwidth permits it without becoming too much of a bottleneck in training times.

Basic examples

Feed-forward

// Import the NeuralNetwork class from brain.js.
const { NeuralNetwork } = require('brain.js');

// Create a new neural network.
const net = new NeuralNetwork();

// Fabricate some training data.
const trainingData = [
  [0, 0, 0],
  [0, 1, 1],
  [1, 0, 1],
  [1, 1, 0]
].map(v => ({ input: v.slice(0, 2), output: v.slice(2) }));

// A custom loss function designed to train for XOR calculations.
// This function is so effective that you could actually train on
// random input data *(`Math.random()` as input data, for example)*
// and the neural network would still come to the correct conclusion
// with little *(if any)* difference in training times.
function loss(actual, expected, inputs, ram) {
  // Calculate the base loss. So far, this is just a normal loss function.
  let error = expected - actual;
  // Penalize incorrect predictions (and thereby reward correct ones) by
  // amplifying the error whenever the rounded prediction disagrees with
  // the XOR of the rounded inputs.
  if (Math.round(actual) !== (Math.round(inputs[0]) ^ Math.round(inputs[1]))) {
    error *= 20;
  }
  // Return the calculated loss.
  return error;
}

// Define the training options.
const trainOptions = {
  errorThresh: 0.011,
  iterations: 15000,
  loss
};

// Train the neural network using the custom loss function.
net.train(trainingData, trainOptions);

// Calculate a ^ b
function xor(a, b) {
  return Math.round(net.run([a, b])[0]);
}

// Try it out!
console.log(xor(0, 0)); // 0
console.log(xor(0, 1)); // 1
console.log(xor(1, 0)); // 1
console.log(xor(1, 1)); // 0
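
To back up the claim in the comments above, here's a hypothetical sketch of rule-based training on random data. It reuses `loss` and the option names from the example; the data set size and the final check are made up for illustration:

// Fabricate random training data: random inputs and meaningless outputs.
// Only the XOR rule inside the custom `loss` function drives learning here;
// the base `expected - actual` term is pure noise.
const randomTrainingData = Array.from({ length: 64 }, () => ({
  input: [Math.random(), Math.random()],
  output: [Math.random()]
}));

// Train a fresh network on the random data with the same custom loss.
// With noisy labels the error threshold may never be reached, so the
// iteration cap is what stops training.
const ruleNet = new NeuralNetwork();
ruleNet.train(randomTrainingData, {
  errorThresh: 0.011,
  iterations: 15000,
  loss
});

// The network should still converge to XOR of the rounded inputs.
console.log(Math.round(ruleNet.run([0, 1])[0])); // 1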

The `updateRAM` function is easy to implement as well. Here's a silly example that just adds random noise to the R/W state matrix:

// Define an `updateRAM` function. It's written like a GPU.js kernel:
// it reads its coordinates within the state matrix from `this.thread`.
function updateRAM(
  ram,
  inputs,
  outputs,
  sizes,
  loss,
  deltaLoss
) {
  // Locate this cell within the R/W state matrix.
  const layer = this.thread.z;
  const neuron = this.thread.y;
  const signal = this.thread.x;

  // Keep the current value and add a random value between 0 and 1.
  return ram[layer][neuron][signal] + Math.random();
}

// Set the `updateRAM` function of your neural network.
const net = new NeuralNetwork();
net.updateRAM = updateRAM;

This causes each value in `NeuralNetwork.ram` to increase by a random value between 0 and 1 each time the neural network is fed input data.
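
Since a GPU-side loss kernel can only see values that travel through `NeuralNetwork.ram`, the state matrix doubles as a parameter channel for the loss function. Here's a hypothetical sketch of that pattern; the `ram[0][0][0]` cell convention is my own, and it assumes the loss receives `ram` with the same `[layer][neuron][signal]` layout used by `updateRAM`:

// A GPU-friendly variant of the XOR loss that reads its penalty weight
// from the R/W state matrix instead of a hard-coded constant.
function ramWeightedLoss(actual, expected, inputs, ram) {
  let error = expected - actual;
  // Hypothetical convention: a penalty weight was stored at ram[0][0][0]
  // (for example, by an `updateRAM` function) before training began.
  const penalty = ram[0][0][0];
  // Compute XOR arithmetically, since bitwise operators may not be
  // available inside GPU kernels.
  const a = Math.round(inputs[0]);
  const b = Math.round(inputs[1]);
  if (Math.round(actual) !== (a + b) % 2) {
    error *= penalty;
  }
  return error;
}

However `ram` ends up being seeded before training, the pattern is the same: anything the GPU-side loss needs has to travel through the state matrix.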

Motivation

Why are we doing this? What use cases does it support? What is the expected outcome? Given a way to specify not just *what* the neural network should learn, but also *how* it should learn, we have the tools to build far more advanced models. It also lays some new foundational pieces: higher-level neural net types can build on the custom loss function feature and on the (very intentionally) general-purpose R/W state matrix.

Some foreseeable use cases: