otac0n / markov

Generic Markov chain generation for C#

Feature: Support for saving/loading? #9

Open rccarlson opened 9 months ago

rccarlson commented 9 months ago

Any chance some kind of serialization could be implemented to save/load Markov chains, rather than having to retrain every time the program runs?

otac0n commented 9 months ago

Yes, I can look into that.

otac0n commented 9 months ago

So, I will say that you can use GetStates(), Add(state, next, weight), and AddTerminalState() to do this now, in the hopes that this unblocks you.
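Below is a minimal sketch of that replay approach. It makes some loud assumptions and does not touch the real `MarkovChain<T>`. A trained order-1 chain over words is modelled with two plain dictionaries. The hypothetical local functions `Add(transitions, state, next, weight)` and `AddTerminal(terminals, state, weight)` stand in for the library's `Add(state, next, weight)` and its terminal-weight method. Saving dumps every (state, next, weight) triple and every terminal weight as tab-separated text. Loading replays those lines into empty dictionaries.

```csharp
using System;
using System.Collections.Generic;
using System.IO;

// Hypothetical model of a trained order-1 chain over words (NOT the real MarkovChain<T>):
// transitions[state][next] = weight, terminals[state] = terminal weight of that state.
var transitions = new Dictionary<string, Dictionary<string, int>>();
var terminals = new Dictionary<string, int>();

// Hypothetical stand-in for Add(state, next, weight): accumulates the weight.
static void Add(Dictionary<string, Dictionary<string, int>> t, string state, string next, int weight)
{
    if (!t.TryGetValue(state, out var nexts))
    {
        nexts = new Dictionary<string, int>();
        t[state] = nexts;
    }
    nexts.TryGetValue(next, out var old); // old is 0 when next is new
    nexts[next] = old + weight;
}

// Hypothetical stand-in for the library's terminal-weight method.
static void AddTerminal(Dictionary<string, int> terms, string state, int weight)
{
    terms.TryGetValue(state, out var old);
    terms[state] = old + weight;
}

// "Train" on two sentences; the empty string stands for the start state.
foreach (var sentence in new[] { "the cat sat", "the dog sat" })
{
    var prev = "";
    foreach (var word in sentence.Split(' '))
    {
        Add(transitions, prev, word, 1);
        prev = word;
    }
    AddTerminal(terminals, prev, 1);
}

// Save: one tab-separated line per transition ("T", state, next, weight) and per
// terminal weight ("E", state, weight). Assumes items contain no tabs or newlines.
var writer = new StringWriter();
foreach (var (state, nexts) in transitions)
    foreach (var (next, weight) in nexts)
        writer.WriteLine($"T\t{state}\t{next}\t{weight}");
foreach (var (state, weight) in terminals)
    writer.WriteLine($"E\t{state}\t{weight}");
string saved = writer.ToString();

// Load: replay every saved line into a fresh, empty chain.
var loadedTransitions = new Dictionary<string, Dictionary<string, int>>();
var loadedTerminals = new Dictionary<string, int>();
using var reader = new StringReader(saved);
string? line;
while ((line = reader.ReadLine()) != null)
{
    var f = line.Split('\t');
    if (f[0] == "T") Add(loadedTransitions, f[1], f[2], int.Parse(f[3]));
    else AddTerminal(loadedTerminals, f[1], int.Parse(f[2]));
}

Console.WriteLine(loadedTransitions["the"]["cat"]); // prints 1
```

The trade-off is that this approach only needs public calls, so it keeps working if the chain's internals change. The cost is that every weight is re-added on load.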

rccarlson commented 9 months ago

I ended up implementing a solution with BinaryReader and BinaryWriter:

In MarkovChain.cs:

    // Field order: items, order, terminals, trainingSize. Deserialize must read them in the same order.
    public void Serialize(BinaryWriter bw, Action<BinaryWriter, T> tWriter)
    {
        // state -> (next item -> weight)
        items.Serialize(bw,
            (w, state) => state.Serialize(w, tWriter),
            (w, nexts) => nexts.Serialize(w, tWriter, (w2, weight) => w2.Write(weight)));
        bw.Write(order);
        // state -> terminal weight
        terminals.Serialize(bw,
            (w, state) => state.Serialize(w, tWriter),
            (w, weight) => w.Write(weight));
        bw.Write(trainingSize);
    }

    public static MarkovChain<T> Deserialize(BinaryReader br, Func<BinaryReader, T> tReader)
    {
        var items = Utility.ReadDictionary(br,
            r => ChainState<T>.Deserialize(r, tReader),
            r => Utility.ReadDictionary(r, tReader, r2 => r2.ReadInt32()));
        var order = br.ReadInt32();
        var terminals = Utility.ReadDictionary(br,
            r => ChainState<T>.Deserialize(r, tReader),
            r => r.ReadInt32());
        var trainingSize = br.ReadInt32();

        return new MarkovChain<T>(items, order, terminals, trainingSize);
    }

In ChainState.cs:

    // A state is written as its length followed by each of its items.
    public void Serialize(BinaryWriter bw, Action<BinaryWriter, T> tWriter)
    {
        bw.Write(items.Length);
        foreach (var item in items)
            tWriter(bw, item);
    }

    public static ChainState<T> Deserialize(BinaryReader br, Func<BinaryReader, T> tReader)
    {
        var len = br.ReadInt32();
        var items = new T[len];
        for (int i = 0; i < len; i++)
        {
            items[i] = tReader(br);
        }
        return new ChainState<T>(items);
    }

In my Utility class:

    // Writes the entry count, then each key/value pair.
    // (This is an extension method, so Utility must be a static class.)
    public static void Serialize<TKey, TValue>(this Dictionary<TKey, TValue> dict, BinaryWriter bw,
        Action<BinaryWriter, TKey> keySerializer, Action<BinaryWriter, TValue> valueSerializer)
        where TKey : notnull
        where TValue : notnull
    {
        bw.Write(dict.Count);
        foreach (var (key, value) in dict)
        {
            keySerializer(bw, key);
            valueSerializer(bw, value);
        }
    }

    // Reads the entry count, then that many key/value pairs, in the order Serialize wrote them.
    public static Dictionary<TKey, TValue> ReadDictionary<TKey, TValue>(BinaryReader br,
        Func<BinaryReader, TKey> readKey, Func<BinaryReader, TValue> readValue)
        where TKey : notnull
        where TValue : notnull
    {
        var len = br.ReadInt32();
        Dictionary<TKey, TValue> dict = new(len);
        for (int i = 0; i < len; i++)
        {
            // Arguments are evaluated left to right, so the key is read before the value.
            dict.Add(readKey(br), readValue(br));
        }
        return dict;
    }

It's a little goofy, but it gets the job done.
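Here is a self-contained sketch of the same count-prefixed layout, round-tripped through a `MemoryStream`. It illustrates the idea under some assumptions and is not the code from the comment above. The `Utility` helpers are re-declared as the local functions `WriteDictionary` and `ReadDictionary` so the file runs on its own. The chain is simplified to order 1 with plain `string` states, whereas the real `ChainState<T>` would be written as a length-prefixed array. `trainingSize` is omitted. The main point it shows is that the reader must consume fields in exactly the order the writer produced them.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

// Example data shaped like MarkovChain<string>'s items/terminals fields, simplified
// (assumption) to an order-1 chain whose states are single strings.
var items = new Dictionary<string, Dictionary<string, int>>
{
    [""] = new() { ["the"] = 2 },
    ["the"] = new() { ["cat"] = 1, ["dog"] = 1 },
};
var terminals = new Dictionary<string, int> { ["cat"] = 1, ["dog"] = 1 };
const int order = 1;

// Save, in the same field order as MarkovChain.Serialize: items, order, terminals.
var stream = new MemoryStream();
using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
{
    WriteDictionary(writer, items,
        (w, state) => w.Write(state),
        (w, nexts) => WriteDictionary(w, nexts, (w2, next) => w2.Write(next), (w2, weight) => w2.Write(weight)));
    writer.Write(order);
    WriteDictionary(writer, terminals, (w, state) => w.Write(state), (w, weight) => w.Write(weight));
}

// Load: rewind and read the fields back in exactly the same order.
stream.Position = 0;
using var reader = new BinaryReader(stream);
var loadedItems = ReadDictionary(reader,
    r => r.ReadString(),
    r => ReadDictionary(r, r2 => r2.ReadString(), r2 => r2.ReadInt32()));
var loadedOrder = reader.ReadInt32();
var loadedTerminals = ReadDictionary(reader, r => r.ReadString(), r => r.ReadInt32());

var dogWeight = loadedItems["the"]["dog"];
Console.WriteLine(dogWeight + " " + loadedOrder + " " + loadedTerminals.Count); // prints "1 1 2"

// Count-prefixed dictionary helpers mirroring the Utility methods above,
// declared as local functions so this file runs on its own.
static void WriteDictionary<TKey, TValue>(BinaryWriter w, Dictionary<TKey, TValue> dict,
    Action<BinaryWriter, TKey> writeKey, Action<BinaryWriter, TValue> writeValue)
    where TKey : notnull
{
    w.Write(dict.Count);
    foreach (var (key, value) in dict)
    {
        writeKey(w, key);
        writeValue(w, value);
    }
}

static Dictionary<TKey, TValue> ReadDictionary<TKey, TValue>(BinaryReader r,
    Func<BinaryReader, TKey> readKey, Func<BinaryReader, TValue> readValue)
    where TKey : notnull
{
    var count = r.ReadInt32();
    var dict = new Dictionary<TKey, TValue>(count);
    for (var i = 0; i < count; i++)
        dict.Add(readKey(r), readValue(r)); // key is read before value (left-to-right evaluation)
    return dict;
}
```

To persist between runs, swap the `MemoryStream` for `File.Create(path)` when saving and `File.OpenRead(path)` when loading. The format has no header or version number, so changing the field order breaks files written by older builds.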

TomArrow commented 5 months ago

> So, I will say that you can use GetStates(), Add(state, next, weight), and AddTerminalState() to do this now, in the hopes that this unblocks you.

Pardon me, where can I find AddTerminalState()? I only see AddTerminalInternal.