Open krflol opened 6 years ago
Thanks!
Built the updated commit and tried it. I'm not sure whether this is working as intended, but it may be worth noting in the documentation: when loading a saved (fileName.bin) Sarsa (and likely QLearning, though I haven't tried it), you have to reinstate the exploration policy. I.e.
targetSarsa = Accord.IO.(etc not in front of my computer right now).Load("C:\file\name\);
will give errors. You have to set
targetSarsa.ExplorationPolicy = (exploration policy)
I spent a long time messing with that because I assumed the exploration policy was saved with it.
Hi @krflol,
Thanks for the update. Do you think you could post a code snippet to help reproduce the issue?
Regards, Cesar
I'll keep it simple. You need two classes to reproduce. I use an initiating class and an adaptation/training class. The initiating class runs the first iteration and saves the network; the second class loads the persisted network so that it can be trained further. I'll make a toy example that should produce the same results, or at least give an idea.
Class1
// using statements
// ...
public class qlearn : Strategy
{
    private double explorationRate = 0.5;
    private double learningRate = 0.8;
    private BoltzmannExploration boltzMan;
    private Sarsa qLearning;

    public override void Initialize()
    {
        if (State == State.Initialize)
        {
            // Build a fresh Sarsa learner with a Boltzmann exploration policy.
            boltzMan = new BoltzmannExploration(explorationRate);
            qLearning = new Sarsa(50 * 50, 3, boltzMan);
            qLearning.LearningRate = learningRate;
        }
    }

    protected override void FirstRunSave()
    {
        // Serialize the learner to disk after the first run.
        using (var fileStream = File.Create("C:\\aissd\\qlearncheckpoint.bin"))
        {
            Accord.IO.Serializer.Save(qLearning, fileStream);
        }
    }
}
Class2
// using statements
// ...
public class qlearnAdapt : Strategy
{
    private double explorationRate = 0.5;
    private double learningRate = 0.8;
    private BoltzmannExploration boltzMan;
    private Sarsa qLearning;

    public override void Initialize()
    {
        if (State == State.Initialize)
        {
            boltzMan = new BoltzmannExploration(explorationRate);
            // Load the learner from the file that Class1 saved.
            qLearning = Accord.IO.Serializer.Load<Sarsa>("C:\\aissd\\qlearncheckpoint.bin");
            qLearning.LearningRate = learningRate;
            // In my experience, using the loaded learner as-is gives errors;
            // you then have to reassign the exploration policy:
            qLearning.ExplorationPolicy = boltzMan;
        }
    }

    protected override void qLearnRunSave()
    {
        // qLearning logic
    }
}
Hope that helps. I should emphasize that these live in two different classes in the same namespace, so the class calling the Load overload relies entirely on deserialization to get everything going again.
Thanks.
Unable to save Sarsa machines. Not marked as serializable.