accord-net / framework

Machine learning, computer vision, statistics and general scientific computing for .NET
http://accord-framework.net
GNU Lesser General Public License v2.1

Sarsa not serializable #1097

Open · krflol opened this issue 6 years ago

krflol commented 6 years ago

Sarsa machines cannot be saved because the Sarsa class is not marked as [Serializable].
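One way to check whether a given Accord.NET build carries the [Serializable] attribute on these classes is reflection. This is a minimal sketch that assumes the Accord.MachineLearning package is referenced. SerializableCheck is a hypothetical helper, not part of Accord. The result depends on the version you build against, so no output is shown.

```csharp
using System;
using Accord.MachineLearning; // Sarsa, QLearning, BoltzmannExploration

// Hypothetical diagnostic helper (not part of Accord.NET).
public static class SerializableCheck
{
    // True if the type is marked [Serializable], which .NET binary
    // serialization (what Accord.IO.Serializer builds on) requires.
    public static bool IsMarked(Type type)
    {
        return type.IsSerializable;
    }

    // Prints whether each reinforcement-learning type in this build is marked.
    public static void Report()
    {
        Type[] types = { typeof(Sarsa), typeof(QLearning), typeof(BoltzmannExploration) };
        foreach (Type t in types)
            Console.WriteLine(t.Name + ": " + IsMarked(t));
    }
}
```

Call SerializableCheck.Report() from any entry point. A False next to Sarsa means the build predates the fix for this issue.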

cesarsouza commented 6 years ago

Thanks!

krflol commented 6 years ago

Built the updated commit and tried it. I'm not sure whether it is working as intended, but it may be worth noting in the documentation: when you load a saved Sarsa machine (fileName.bin), you have to reinstate the exploration policy. This probably applies to QLearning too, though I haven't tried it. In other words, just doing targetSarsa = Accord.IO.(etc, not in front of my computer right now).Load("C:\file\name\");

will give errors. You also have to set

targetSarsa.ExplorationPolicy = (exploration policy)

I spent a long time messing with that, because I had assumed the exploration policy was saved along with the machine.
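The workaround can be wrapped so the policy is never forgotten after a load. This is a minimal sketch that assumes an Accord.NET build that includes the [Serializable] fix from this issue, plus the Accord.IO Serializer.Save(obj, path) and Serializer.Load<T>(path) overloads. SarsaCheckpoint is a hypothetical helper, not an Accord API.

```csharp
using System;
using Accord.MachineLearning; // Sarsa, IExplorationPolicy

// Hypothetical helper (not part of Accord.NET): wraps save/load so the
// exploration policy is always reattached after deserialization.
public static class SarsaCheckpoint
{
    public static void Save(Sarsa machine, string path)
    {
        // Writes the machine to a file using Accord's serializer.
        Accord.IO.Serializer.Save(machine, path);
    }

    public static Sarsa Load(string path, IExplorationPolicy policy)
    {
        // Fail early instead of returning a machine that errors on first use.
        if (policy == null)
            throw new ArgumentNullException(nameof(policy));

        Sarsa machine = Accord.IO.Serializer.Load<Sarsa>(path);

        // Workaround from this thread: do not rely on the policy surviving the round trip.
        machine.ExplorationPolicy = policy;
        return machine;
    }
}
```

For example, a loading class could call qLearning = SarsaCheckpoint.Load(path, new BoltzmannExploration(explorationRate)); and get a machine that is ready to use in a single step. Requiring the policy as a parameter turns the missing-policy mistake into an immediate ArgumentNullException rather than a later failure.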

cesarsouza commented 6 years ago

Hi @krflol,

Thanks for the update. Do you think you could post a code snippet to help reproduce the issue?

Regards, Cesar

krflol commented 6 years ago

I'll keep it simple. You need two classes to reproduce it: an initiating class and an adaptation/training class. The initiating class runs the first iteration and saves the network; the second class loads the persisted network so it can be trained further. I'll make a toy example that should produce the same results, or at least give an idea.

Class1

//using statements
//...
//class

public class qlearn : Strategy
{
    private double explorationRate = 0.5;
    private double learningRate = 0.8;
    private BoltzmannExploration boltzMan;
    private Sarsa qLearning;

    public override void Initialize()
    {
        if (State == State.Initialize)
        {
            // Create the exploration policy and a Sarsa machine with 50*50 states and 3 actions
            boltzMan = new BoltzmannExploration(explorationRate);
            qLearning = new Sarsa(50 * 50, 3, boltzMan);
            qLearning.LearningRate = learningRate;
        }
    }

    protected override void FirstRunSave()
    {
        // Serialize the machine into memory, then copy it to the checkpoint file
        using (var stream = new MemoryStream())
        using (var fileStream = File.Create("C:\\aissd\\qlearncheckpoint.bin"))
        {
            Accord.IO.Serializer.Save(qLearning, stream);
            stream.Seek(0, SeekOrigin.Begin);
            stream.CopyTo(fileStream);
        }
    }
}

Class2

//using statements
//...
//class

public class qlearnAdapt : Strategy
{
    private double explorationRate = 0.5;
    private double learningRate = 0.8;
    private BoltzmannExploration boltzMan;
    private Sarsa qLearning;

    public override void Initialize()
    {
        if (State == State.Initialize)
        {
            boltzMan = new BoltzmannExploration(explorationRate);
            qLearning = Accord.IO.Serializer.Load<Sarsa>("C:\\lifealianssd\\qlearncheckpoint.bin");
            qLearning.LearningRate = learningRate;

            // In my experience, using the loaded machine without the next line gives errors:
            // the exploration policy has to be reassigned after loading.
            qLearning.ExplorationPolicy = boltzMan;
        }
    }

    protected override void qLearnRunSave()
    {
        //qLearning logic
    }
}

Hope that helps. I should emphasize that these are two different classes in the same namespace, so the class calling the Load overload relies entirely on deserialization to get everything going again.

Thanks.