FidoProject / Fido

A lightweight C++ machine learning library for embedded electronics and robotics.
http://fidoproject.github.io/
MIT License

Save network training state. #67

Closed mekerhult closed 6 years ago

mekerhult commented 6 years ago

I have seen from the tests that it is possible to save the network, but is it possible to save its weights after it has been trained? I.e., I would like to train the net and then store it to flash, so that when I boot up my embedded project the next time it is already trained. Is this possible?

Great project btw!

joshuagruenstein commented 6 years ago

This is totally doable. Note the store() function on a variety of classes, all of which accept streams to write to, depending on what exactly you're looking to save. Good luck with your project!

mekerhult commented 6 years ago

Thanks for replying, Joshua! Unfortunately, I cannot get that to work. Let me just brief you on my project real quick so you know what I would like to achieve. I want to run Fido on an embedded system (that part, including compiling, is done), and that system shall drive an OpenAI simulation (CartPole) on a PC via websockets. I am expecting the learning to take quite some time due to the limited CPU power of the embedded system. That's why I would like to save and recall the network weights.

Now, I plan to use the FidoControlSystem class to run this, as it seems to be the most advanced algorithm and works with experience replay. Am I right in this? This class does have a store() function, but how do I load the network?

Thanks, Marcus

joshuagruenstein commented 6 years ago

Two things:

  1. Looking through the code, it seems like only certain classes currently support loading from streams in their constructor. If all you wanted to store were the neural net weights, you could simply modify the constructor for FidoControlSystem to accept a stream of weights, pass that down the hierarchy until the network is initialized, then initialize the NeuralNet from that stream. Unfortunately this behavior isn't already implemented in the class, but it should be relatively trivial to accomplish. This also leaves out additional hyperparameters you'd want to save, such as Adadelta parameters, uncertainty, etc. However, once again this should be simple functionality to add if you want it.
  2. If you want to use FidoControlSystem with CartPole, I suggest you run it locally on your computer first to make sure this is even feasible. Wire-fitted Q-learning isn't really designed for that sort of task, and I don't know if you'd achieve solid performance. The Fido control system also adds a bunch of wackiness, so I can't predict how it might perform on a task like that. You might be better off using traditional Q-learning with a neural network, which Fido also supports and which should converge on CartPole. Additionally, while you're correct that FidoControlSystem has experience replay built in while QLearn does not, you should find that experience replay is unnecessary for CartPole and will simply increase your training time.

In conclusion, I'd suggest using the QLearn class instead with standard backprop, as this will not only be ideal for your task but also the simplest way to store and load weights. All you'd have to do is initialize the neural network from your stored parameters, then pass that into the QLearn constructor. This is all already implemented functionality, and if you need any help getting it set up, feel free to continue commenting here.

truell20 commented 6 years ago

I'm gonna close this for now.