asieradzk / RL_Matrix

Deep Reinforcement Learning in C#

Reinforcement Learning The Identity Function #9

Closed ds5678 closed 7 months ago

ds5678 commented 7 months ago

Context

I was having difficulty using reinforcement learning for more complex problems, so I made a simple project to test the learning potential. In the code below, I am trying to teach the Boolean identity function to the agent, but the agent gives random answers before and after training.

Code

using OneOf;
using RLMatrix; //NuGet v0.2.0

namespace ReinforcementLearning.Trivial;

internal class Program
{
    static void Main(string[] args)
    {
        var options = new PPOAgentOptions(lr: .005f);
        List<IEnvironment<float[]>> environments = [new TrivialEnvironment(), new TrivialEnvironment(), new TrivialEnvironment()];
        var agent = new PPOAgent<float[]>(options, environments, new PPONetProviderBase<float[]>(16, 2));

        Console.WriteLine($"Pre-training: {Evaluate(agent)}");

        for (int i = 0; i < 10000; i++)
        {
            Console.Write(i);
            agent.Step();
            Console.Write("\b\b\b\b\b\b\b\b\b\b\b\b");
        }

        Console.WriteLine($"Post-training: {Evaluate(agent)}");
    }

    private static float Evaluate(PPOAgent<float[]> agent, int count = 1000)
    {
        float[] state = new float[1];
        int correct = 0;
        for (int i = 0; i < count; i++)
        {
            state[0] = TrivialEnvironment.RandomValue();
            (int[] greedyDiscreteActions, float[] meanContinuousActions) = agent.SelectAction(state, false);
            if (state[0] == greedyDiscreteActions[0])
            {
                correct++;
            }
        }
        return (float)correct / count;
    }

    private sealed class TrivialEnvironment : IEnvironment<float[]>
    {
        public const float CorrectAnswerReward = 1;
        public const float WrongAnswerPenalty = -1;

        private readonly float[] state = [RandomValue()];

        public static int RandomValue()
        {
            return Random.Shared.Next(2);
        }

        public int stepCounter { get; set; }
        public int maxSteps { get; set; } = int.MinValue;
        public bool isDone { get; set; }
        public OneOf<int, (int, int)> stateSize { get; set; } = 1;
        public int[] actionSize { get; set; } = [2];

        public float[] GetCurrentState() => state;

        public TrivialEnvironment() => Initialise();

        public void Initialise() => Reset();

        public void Reset()
        {
            stepCounter = 0;
            isDone = false;
        }

        public float Step(int[] actionsIds)
        {
            float input = state[0];
            float output = actionsIds[0];

            state[0] = RandomValue();

            return input == output ? CorrectAnswerReward : WrongAnswerPenalty;
        }
    }
}

Am I using the library incorrectly?

asieradzk commented 7 months ago

Hey. There are a couple of issues with how you use the library:

  1. You should end the episode at some point: maxSteps should be somewhere between 1 and a not-too-large int. For infinite episodes you must learn how to spoof them as finite episodes. Your agent never learns because the episode never finishes. You should set the "done" flag to true, and then Reset will be called.

  2. For PPO it is best to use some batch size.

  3. private readonly float[] state = [RandomValue()]; means that the network and your "Step" method will see different states, so the reward is assigned randomly. It is better to assign the state in Reset or on a per-step basis.

  4. Having a "0f" value for observations can lead to issues; it is better to start at 1.

  5. I'm not sure your "trivial" env is very trivial: you have single-step episodes, and it may take a lot of hyperparameter tuning for the network to get something useful out of the signals.
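
Point 1 can be sketched without the library. The class below is a hypothetical helper (the name FiniteEpisodeCounter and its members are not part of RLMatrix) that only shows the bookkeeping: count steps, raise a done flag once a finite limit is reached, and clear both on reset. It rejects limits below 1, which is the case the original code hits with int.MinValue.

```csharp
using System;

// Hypothetical, library-free helper: none of these names come from RLMatrix.
public sealed class FiniteEpisodeCounter
{
    private readonly int maxSteps;
    private int stepCounter;

    public bool IsDone { get; private set; }

    public FiniteEpisodeCounter(int maxSteps)
    {
        // A limit such as int.MinValue can never describe a valid episode length.
        if (maxSteps < 1)
            throw new ArgumentOutOfRangeException(nameof(maxSteps), "An episode needs at least one step.");
        this.maxSteps = maxSteps;
    }

    // Call once per environment step; returns true when the episode should end.
    public bool Tick()
    {
        stepCounter++;
        IsDone = stepCounter >= maxSteps;
        return IsDone;
    }

    // Call when the episode restarts.
    public void Reset()
    {
        stepCounter = 0;
        IsDone = false;
    }
}
```

An environment could call Tick() at the end of its Step method and copy the result into its own done flag, which is one way to cut an endless task into finite episodes.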

I gave it a go and ran into some issues; maybe I broke something without realising. Give me a day.

asieradzk commented 7 months ago

Here, this is the correct way to set up the env, with the tips above.

Be sure to grab the newest NuGet package, 0.2.1. I may or may not have broken some things in the previous version :)


    public class TrivialEnvironment : IEnvironment<float[]>
    {
        public const float CorrectAnswerReward = 1;
        public const float WrongAnswerPenalty = -1;

        public float[] state;

        public static int RandomValue()
        {
            return Random.Shared.Next(2);
        }

        public int stepCounter { get; set; }
        public int maxSteps { get; set; } = 10;
        public bool isDone { get; set; }
        public OneOf<int, (int, int)> stateSize { get; set; } = 1;
        public int[] actionSize { get; set; } = [2];

        public float[] GetCurrentState() => state;

        public TrivialEnvironment() => Initialise();

        public void Initialise() => Reset();

        public void Reset()
        {
            state = new float[1] { RandomValue() };
            stepCounter = 0;
            isDone = false;
        }

        public float Step(int[] actionsIds)
        {
            float input = state[0];
            float output = actionsIds[0];

            state[0] = RandomValue();

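            // Pre-increment so the episode ends after exactly maxSteps steps
            // (the post-increment form allowed one extra step).
            if (++stepCounter >= maxSteps)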
            {
                isDone = true;
            }

            return input == output ? CorrectAnswerReward : WrongAnswerPenalty;
        }
    }
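
Before training, it may help to check that the reward signal matches what a policy actually observes. The sketch below is library-free: IdentityEnv and EnvSanityCheck are hypothetical stand-ins that mirror the reward logic of the environment above without implementing IEnvironment or using OneOf, and they are not part of RLMatrix.

```csharp
using System;

// Hypothetical stand-in that mirrors the environment's reward logic
// without depending on RLMatrix or OneOf.
public sealed class IdentityEnv
{
    private readonly Random rng;
    private float state;

    public IdentityEnv(int seed)
    {
        rng = new Random(seed);
        Reset();
    }

    public float CurrentState => state;

    public void Reset() => state = rng.Next(2);

    // Rewards the action that equals the state observed before the step,
    // then draws the next state.
    public float Step(int action)
    {
        float reward = state == action ? 1f : -1f;
        state = rng.Next(2);
        return reward;
    }
}

public static class EnvSanityCheck
{
    // Mean reward per step for a policy that maps an observation to an action.
    public static float MeanReward(IdentityEnv env, Func<float, int> policy, int steps)
    {
        float total = 0f;
        for (int i = 0; i < steps; i++)
        {
            int action = policy(env.CurrentState);
            total += env.Step(action);
        }
        return total / steps;
    }
}
```

An oracle policy such as s => (int)s should average exactly 1 per step, and its inverse s => 1 - (int)s exactly -1. A result anywhere in between would suggest that the state the policy sees differs from the state Step scores against.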