BrainJS / brain.js

🤖 GPU accelerated Neural networks in JavaScript for Browsers and Node.js
https://brain.js.org
MIT License

Faster training #102

Closed arguiot closed 6 years ago

arguiot commented 6 years ago

Hello 👋, I would like to know whether it would be possible to improve the performance of the train function, either by using the GPU or by using multiple cores/threads in the Node version.

Thank you for building this library

robertleeplummerjr commented 6 years ago

The short answer: yes. The longer answer: this is something we've put a great deal of thought into, and we ended up adopting gpu.js and helping get it to version 1. I just merged in an experimental branch, and that is where the "yes" above is satisfied. That gives us our short-term gain for feed-forward neural networks.

The long-term "yes" is that we are (re-?)structuring the recurrent architecture (which is still experimental) to create a single, unified way of composing layers. This branch: https://github.com/BrainJS/brain.js/tree/nn-gpu-layers embodies where we are going next. Not only will it be fast (GPU), but also easy and straightforward. This is something we feel is greatly missing from the machine learning community. The API is simple:

const net = new FeedForward({
  inputLayer: () => { /* return a layer */ },
  hiddenLayers: [
    () => { /* return a layer */ },
    () => { /* return a layer */ },
    () => { /* return a layer */ },
  ],
  outputLayer: () => { /* return a layer */ },
});

What is most important about this is that it will run on the GPU first, with a CPU fallback, and that the recurrent neural network will share the same basic architectural frontend. So no matter if you are building a simple perceptron, a Vincent Van Gogh image manipulator, a Johnny5 robot, or a cancer detector, you'll have the same easy means of composition.
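To make the composition idea concrete, here is a minimal sketch in plain JavaScript. It is not brain.js code: `dense`, `relu`, and `compose` are hypothetical names invented for illustration, layers are modelled as plain functions on arrays, and nothing here runs on the GPU. It only shows how an input layer, a list of hidden layers, and an output layer could be chained from factory functions.

```javascript
// Hypothetical sketch: none of these helpers are part of brain.js.
// A "layer" here is just a function from an input array to an output array.

// Fully connected layer: output[i] = biases[i] + sum_j weights[i][j] * input[j]
const dense = (weights, biases) => (input) =>
  weights.map((row, i) =>
    row.reduce((sum, w, j) => sum + w * input[j], biases[i])
  );

// Element-wise ReLU activation.
const relu = () => (input) => input.map((x) => Math.max(0, x));

// Call each layer factory once, then feed the activation through the layers in order.
function compose({ inputLayer, hiddenLayers, outputLayer }) {
  const layers = [inputLayer, ...hiddenLayers, outputLayer].map((make) => make());
  return (input) => layers.reduce((activation, layer) => layer(activation), input);
}

const run = compose({
  inputLayer: () => (input) => input, // identity: passes the input through
  hiddenLayers: [
    () => dense([[1, -1], [0.5, 0.5]], [0, 0]),
    () => relu(),
  ],
  outputLayer: () => dense([[1, 1]], [0]),
});

const output = run([2, 3]);
console.log(output);
```

The point of the shape is that the same `compose` call would work whether the factories return feed-forward or recurrent layers, as long as every layer exposes the same interface.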

I could go on and on. This has been well over a year in the making, with a lot of very smart people asking whether current neural networks have been implemented the right way, and the answer we keep coming back to is "no". But I digress, and I thank you for asking this very fundamental question.

nsantini commented 6 years ago

I've been trying to get NeuralNetworkGPU to work, but there seems to be a bug in it; see https://github.com/BrainJS/brain.js/issues/116

arguiot commented 6 years ago

@nsantini Ok. I hope this bug will be fixed soon. I'll try to investigate, and if I find a solution to the problem I'll answer in #116.

robertleeplummerjr commented 6 years ago

I've been thinking about this all day. It has to do with readPixels from WebGL, which is notoriously slow because it forces all layers to sync and stop. I believe I have a possible solution: we write the outputs of all the trained neural networks to a single texture, and then sum the texture to obtain the single error of the whole net. This value could then be read out using readPixels, because it is essentially a single pixel and shouldn't cost too many resources. The other option is that we render the output of that last network to a WebGL canvas.

I started brainstorming that here: https://github.com/gpujs/gpu.js/issues/240 so in the worst case, if we have to avoid readPixels altogether, we get something like: https://www.shadertoy.com/view/ldjXDD

Super pretty.
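As a rough illustration of the "sum the texture down to a single value" idea, here is a CPU-side sketch in plain JavaScript. It is only an assumption about how such a reduction could be structured, not gpu.js or brain.js code: `reducePass` and `sumToSingleValue` are hypothetical names, and on a GPU each pass would be one kernel run over a texture, with only the final one-element result read back.

```javascript
// Illustrative CPU-side sketch of a pairwise sum reduction (hypothetical helpers).
// Each pass halves the number of values, the way a GPU reduction kernel would.
function reducePass(values) {
  const half = Math.ceil(values.length / 2);
  const next = new Float32Array(half);
  for (let i = 0; i < half; i++) {
    // Pair element i with element i + half; treat a missing partner as 0.
    const partner = i + half < values.length ? values[i + half] : 0;
    next[i] = values[i] + partner;
  }
  return next;
}

// Repeat passes until one value is left: the only value that would need reading back.
function sumToSingleValue(errors) {
  let current = Float32Array.from(errors);
  let passes = 0;
  while (current.length > 1) {
    current = reducePass(current);
    passes++;
  }
  return { total: current.length === 1 ? current[0] : 0, passes };
}

const { total, passes } = sumToSingleValue([0.5, 0.25, 0.125, 0.125, 1]);
console.log(total, passes);
```

The number of passes grows with the logarithm of the number of values, so even a large error texture collapses to a single pixel in a handful of steps.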

robertleeplummerjr commented 6 years ago

ok, I'm way off, #116 is the culprit here. I'm solving this tonight!

robertleeplummerjr commented 6 years ago

ok, #116 is resolved, but I believe I'm now getting exploding gradients. Still investigating.

arguiot commented 6 years ago

Ok, it's great 👍 that you fixed the first bug. Do you think we can expect GPU-based training in the next days or weeks if you fix the gradient issue?

robertleeplummerjr commented 6 years ago

I think that is a good target, or I'll lose my sanity 😎

robertleeplummerjr commented 6 years ago

@arguiot v1.0.1 handles the above issue and is released to npm. However, I don't want to close this issue until we can prove it is actually faster on the GPU. It will be, but we were having some issues with the number of times readPixels was being called. I will try to look into this today.
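For anyone who wants to check the "actually faster" claim themselves, a small timing harness along these lines might help. This is a sketch, not part of brain.js: `timeTraining` is a hypothetical helper, and the workload below is a stand-in loop so the snippet runs without any dependencies.

```javascript
// Hypothetical helper, not part of brain.js: times any synchronous function
// over several runs and reports the median, which is less noisy than a single run.
const now =
  typeof performance !== "undefined" && typeof performance.now === "function"
    ? () => performance.now()
    : () => Date.now();

function timeTraining(label, trainFn, runs = 3) {
  const durations = [];
  for (let i = 0; i < runs; i++) {
    const start = now();
    trainFn();
    durations.push(now() - start);
  }
  durations.sort((a, b) => a - b);
  return { label, runs, medianMs: durations[Math.floor(durations.length / 2)] };
}

// Stand-in workload so the sketch is self-contained.
const result = timeTraining("stand-in workload", () => {
  let acc = 0;
  for (let i = 0; i < 1e5; i++) acc += Math.sqrt(i);
  return acc;
});
console.log(result);

// With brain.js installed, one could pass real training calls instead, e.g.
// (assuming the installed version exposes both classes):
//   timeTraining("cpu", () => new brain.NeuralNetwork().train(data));
//   timeTraining("gpu", () => new brain.NeuralNetworkGPU().train(data));
```

Using the same training data and options for both runs keeps the comparison fair; small networks may well come out slower on the GPU because of the readback overhead discussed above.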