jmacglashan / burlap

Repository for the ongoing development of the Brown-UMBC Reinforcement Learning And Planning (BURLAP) java library
Apache License 2.0

Q Learning Stochastic Movements #91

Closed onaclov2000 closed 8 years ago

onaclov2000 commented 8 years ago

I'm currently using BURLAP for a project in my ML class. I'm attempting to make the movements have a success probability other than 1, but I don't see anything in the source that shows how to do that. I have set the following:

GridWorldDomain x = new GridWorldDomain(9, 9);
x.setProbSucceedTransitionDynamics(.8);

However, when I try something like .1, I would expect the "directions" the Q-learning algorithm discovers to look abnormal to a person, yet make complete sense once you realize the movements nearly end up the opposite of what was intended.

The last bit may be confusing so here is an example that hopefully clarifies.

Suppose you are in the bottom-left corner and the success probability is .1, so the remaining .9 is split evenly (.3 each) among the other three directions. If you select UP, you actually move up with probability only .1, while down and left (both of which leave you stationary against the walls) and right each get .3. The logical move is therefore to select right: down and left still keep you stationary, right itself only succeeds with .1, and you move up with probability .3, which is higher than if you had selected up directly.

I was hoping to experiment with a Q-learner that would discover that selecting right tends to result in moving UP (or rather, that selecting east has a better chance of moving up, and so on).

Looking in the code I don't see any references to transition dynamics in the Q learning class, but I could just be misunderstanding what I'm seeing.

jmacglashan commented 8 years ago

Transition dynamics are always encoded in the Action object (each Action is responsible for its own transitions) through its performActionHelper method (for sampling a transition) and its getTransitions method, if it implements FullActionModel.

Q-Learning doesn't reference the transition dynamics because Q-learning is a learning algorithm, which means the agent doesn't get to know the transition dynamics, it only gets to observe the outcomes of its actions in an environment. (And if the environment is a SimulatedEnvironment, then it determines the transitions using the transition dynamics specified in the Action.)
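
To make that concrete, this is roughly all the agent's side of the loop looks like (a quick sketch using the same SimulatedEnvironment, EnvironmentOutcome, NullRewardFunction, and NullTermination classes as the examples further down; treat it as illustrative rather than copy-paste ready):

    GridWorldDomain gwd = new GridWorldDomain(11, 11);
    gwd.setProbSucceedTransitionDynamics(0.1);
    Domain domain = gwd.generateDomain();
    State s = GridWorldDomain.getOneAgentNoLocationState(domain, 0, 0);
    Environment env = new SimulatedEnvironment(domain, new NullRewardFunction(), new NullTermination(), s);

    //the agent's view: pick an action, execute it in the environment, observe only the outcome
    GroundedAction east = domain.getAction(GridWorldDomain.ACTIONEAST).getAssociatedGroundedAction();
    EnvironmentOutcome eo = east.executeIn(env); //eo.op is the next state, eo.r the reward, eo.terminated the terminal flag
    System.out.println("reward: " + eo.r + " terminated: " + eo.terminated);

Q-learning's updates are driven entirely by outcomes like eo; the transition probabilities themselves live only inside the domain's Action objects and the SimulatedEnvironment that samples from them.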

As far as GridWorldDomain goes, it provides some helper methods for common ways of adding stochasticity in a grid world without requiring the client to wholly reimplement the Action objects. One of them is indeed setProbSucceedTransitionDynamics, which has the effect you described. Another is setTransitionDynamics, which lets you specify the movement probabilities in all directions for each of the standard NSEW movement directions in a grid world.
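
For example, a minimal sketch of both helpers (the north/south/east/west row and column ordering I'm showing for setTransitionDynamics is from memory, so double-check it against the Javadoc before relying on it):

    GridWorldDomain gwd = new GridWorldDomain(9, 9);

    //option 1: the selected direction succeeds with probability 0.8 and the
    //remaining 0.2 is split evenly among the other three directions
    gwd.setProbSucceedTransitionDynamics(0.8);

    //option 2: specify the full outcome distribution for each selected direction;
    //transitionDynamics[selected][actual], assuming the order north, south, east, west
    double[][] transitionDynamics = new double[][]{
            {0.8, 0.0, 0.1, 0.1}, //selecting north
            {0.0, 0.8, 0.1, 0.1}, //selecting south
            {0.1, 0.1, 0.8, 0.0}, //selecting east
            {0.1, 0.1, 0.0, 0.8}  //selecting west
    };
    gwd.setTransitionDynamics(transitionDynamics);

    Domain domain = gwd.generateDomain();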

If you want to verify the effect of the probabilities you're setting in an interactive way, I recommend using a VisualExplorer or TerminalExplorer. Or, you can always just sample the actions in a state yourself and see what outcome states you get.

I'll note that learning this with Q-learning will probably require a lot of data and a small learning rate, because Q-learning is a sample-based, model-free learner, and with actions this noisy it will be difficult for it to converge on the corresponding value functions.
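
As a rough sketch of what I mean (the specific numbers here are just illustrative starting points, not tuned values), you would use a small learning rate and run many episodes so the noisy outcomes average out:

    GridWorldDomain gwd = new GridWorldDomain(11, 11);
    gwd.setMapToFourRooms();
    gwd.setProbSucceedTransitionDynamics(0.1);
    Domain domain = gwd.generateDomain();
    State s = GridWorldDomain.getOneAgentNoLocationState(domain, 0, 0);
    SimulatedEnvironment env = new SimulatedEnvironment(domain, new UniformCostRF(),
            new GridWorldTerminalFunction(10, 10), s);

    //QLearning(domain, discount, hashing factory, initial Q-value, learning rate)
    QLearning ql = new QLearning(domain, 0.99, new SimpleHashableStateFactory(), 0., 0.005);
    for(int i = 0; i < 5000; i++){
        ql.runLearningEpisode(env);
        env.resetEnvironment(); //put the agent back at the start state for the next episode
    }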

I'm going to close, since I don't think this is an issue in the code.

onaclov2000 commented 8 years ago

That's totally fine that you closed it; I wasn't sure of the "best" place to ask how this works. I wouldn't be surprised if the results were quite painful and the learner had a hard time, but that was somewhat intentional for my experiments.

I am using

    Visualizer v = GridWorldVisualizer.getVisualizer(gwdg.getMap());
    new EpisodeSequenceVisualizer(v, domain, outputpath);

for visualization.

Just to clarify: Q-learning can work with stochastic movements, but do I need to use setTransitionDynamics?

I have been trying it with setProbSucceedTransitionDynamics but am not seeing the expected "confused movements".

jmacglashan commented 8 years ago

In general, if you have questions, the Google Groups page is a good place to ask, and if there is a bug or feature request, GitHub issues is the place to put it. That said, I can answer your questions here for the moment.

Yes, Q-learning works with stochastic domains, and setProbSucceedTransitionDynamics is one valid way to set stochasticity for GridWorlds specifically. (Though you should make sure you call it before you generate your domain; otherwise the generated domain won't have the settings you specified.)
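
That is, something like this (just a sketch of the ordering point):

    GridWorldDomain gwdg = new GridWorldDomain(9, 9);
    gwdg.setProbSucceedTransitionDynamics(0.1); //set the stochasticity BEFORE generating
    Domain domain = gwdg.generateDomain();      //this domain's actions now use the 0.1 dynamics

    //calling setProbSucceedTransitionDynamics down here would only affect domains you
    //generate afterwards, not the domain object created above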

EpisodeSequenceVisualizer will let you review observed episodes and is good for that. However, what I meant was using VisualExplorer or TerminalExplorer, which are classes that let you act as the agent choosing actions in an environment (if you just provide them a domain rather than an environment in the constructor, they automatically create a SimulatedEnvironment from the domain). These are really helpful tools for making sure your domain is specified as you want before you run a planning or learning algorithm on it.

Here is some example code in your setting that will give you a visual explorer as well as show you the exact probabilities (as pulled from getTransitions) and the sampled probabilities (as pulled from performActionHelper). In all cases, you should find that the specified probabilities are as expected.

public static void main(String[] args) {

        GridWorldDomain gwd = new GridWorldDomain(11, 11);
        gwd.setMapToFourRooms();
        gwd.setProbSucceedTransitionDynamics(0.1); //add very high stochasticity
        Domain domain = gwd.generateDomain();
        State s = GridWorldDomain.getOneAgentNoLocationState(domain, 0, 0);

        //get exact transition probabilities for east
        GroundedAction a = domain.getAction("east").getAssociatedGroundedAction();
        List<TransitionProbability> tps = a.getTransitions(s);
        for(TransitionProbability tp : tps){
            System.out.println(direction(s, tp.s) + ": " + tp.p);
        }

        System.out.println("--");

        //get sample based transition probabilities (which is what an Environment will use)
        HashedAggregator<String> ag = new HashedAggregator<String>();
        for(int i = 0; i < 10000; i++){
            ag.add(direction(s, a.executeIn(s)), 1.);
        }
        for(Map.Entry<String, Double> e : ag.entrySet()){
            System.out.println(e.getKey() + ": " + e.getValue() / 10000);
        }

        VisualExplorer exp = new VisualExplorer(domain, GridWorldVisualizer.getVisualizer(gwd.getMap()), s);
        exp.addKeyAction("w", "north");
        exp.addKeyAction("s", "south");
        exp.addKeyAction("d", "east");
        exp.addKeyAction("a", "west");

        exp.initGUI();
    }

    /*
     * Returns a string indicating whether the agent moved north, south, east, or west,
     * or stayed in place, between states s1 and s2
     */
    public static String direction(State s1, State s2){
        ObjectInstance a1 = s1.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
        ObjectInstance a2 = s2.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);

        int x1 = a1.getIntValForAttribute("x");
        int y1 = a1.getIntValForAttribute("y");

        int x2 = a2.getIntValForAttribute("x");
        int y2 = a2.getIntValForAttribute("y");

        if(x2-x1 > 0){
            return "east";
        }
        if(x2-x1 < 0){
            return "west";
        }
        if(y2-y1 > 0){
            return "north";
        }
        if(y2-y1 < 0){
            return "south";
        }

        return "stay";
    }

For the visual explorer, make sure you click on the image first to give it keyboard focus if you want to use the wsad keys (or you can manually enter the names of the actions in the text field at the bottom and then hit the execute button).

onaclov2000 commented 8 years ago

Sorry for being dense. I see how that would work with the explorer world that you manually control, but for Q-learning specifically I have set the transition dynamics via

    gwdg.setProbSucceedTransitionDynamics(transition);
    domain = gwdg.generateDomain();

When I run the Q-learner I get results that look like this:

[results screenshot not reproduced in text]

Here is my full class (sorry, it's kind of a mess):

public class AdvancedBehavior {

GridWorldDomain gwdg;
Domain domain;
GridWorldRewardFunction rf;
GridWorldTerminalFunction tf;
StateConditionTest goalCondition;
State initialState;
HashableStateFactory hashingFactory;
Environment env;

public AdvancedBehavior(int[][] map, double transition, double movementReward){

    gwdg = new GridWorldDomain(map.length,map[0].length); 
    gwdg.setMap(map);
    gwdg.setProbSucceedTransitionDynamics(transition);
    domain = gwdg.generateDomain();

    rf = new GridWorldRewardFunction(domain, movementReward);
    rf.setReward(map.length-1,map[0].length-1,1.0); // should be the top-right corner
    rf.setReward(map.length-1,map[0].length-2,-1.0);
    tf = new GridWorldTerminalFunction();
    tf.markAsTerminalPosition(map.length-1,map[0].length - 1);
    tf.markAsTerminalPosition(map.length-1,map[0].length - 2);

    goalCondition = new TFGoalCondition(tf);

    initialState = GridWorldDomain.getOneAgentNLocationState(domain, 1);
    GridWorldDomain.setAgent(initialState, 0, 0);
    GridWorldDomain.setLocation(initialState, 0, map.length-1,map[0].length - 1);

    hashingFactory = new SimpleHashableStateFactory();

    env = new SimulatedEnvironment(domain, rf, tf, initialState);

}

public void visualize(String outputpath){
    Visualizer v = GridWorldVisualizer.getVisualizer(gwdg.getMap());
    new EpisodeSequenceVisualizer(v, domain, outputpath);
}

public void manualValueFunctionVis(ValueFunction valueFunction, Policy p, String title){

    List<State> allStates = StateReachability.getReachableStates(initialState, 
                                (SADomain)domain, hashingFactory);

    //define color function
    LandmarkColorBlendInterpolation rb = new LandmarkColorBlendInterpolation();
    rb.addNextLandMark(0., Color.RED);
    rb.addNextLandMark(1., Color.BLUE);

    //define a 2D painter of state values, specifying which attributes correspond 
    //to the x and y coordinates of the canvas
    StateValuePainter2D svp = new StateValuePainter2D(rb);
    svp.setXYAttByObjectClass(GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTX,
            GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTY);

    //create our ValueFunctionVisualizer that paints for all states
    //using the ValueFunction source and the state value painter we defined
    ValueFunctionVisualizerGUI gui = new ValueFunctionVisualizerGUI(allStates, svp, valueFunction);
    gui.setTitle(title);
    //define a policy painter that uses arrow glyphs for each of the grid world actions
    PolicyGlyphPainter2D spp = new PolicyGlyphPainter2D();
    spp.setXYAttByObjectClass(GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTX,
            GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTY);
    spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONNORTH, new ArrowActionGlyph(0));
    spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONSOUTH, new ArrowActionGlyph(1));
    spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONEAST, new ArrowActionGlyph(2));
    spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONWEST, new ArrowActionGlyph(3));
    spp.setRenderStyle(PolicyGlyphPainter2D.PolicyGlyphRenderStyle.DISTSCALED);

    //add our policy renderer to it
    gui.setSpp(spp);
    gui.setPolicy(p);

    //set the background color for places where states are not rendered to black
    gui.setBgColor(Color.BLACK);

    //start it
    gui.initGUI();

}

public void experimentAndPlotter(){

    //different reward function for more interesting results
    ((SimulatedEnvironment)env).setRf(new GoalBasedRF(this.goalCondition, 5.0, -0.1));

    /**
     * Create factories for Q-learning agent and SARSA agent to compare
     */
    LearningAgentFactory qLearningFactory = new LearningAgentFactory() {
        @Override
        public String getAgentName() {
            return "Q-Learning";
        }

        @Override
        public LearningAgent generateAgent() {
            QLearning q = new QLearning(domain, 0.25, hashingFactory, 1.0, .1);
            q.initializeForPlanning(rf,tf,100);
            return (LearningAgent)q;
        }
    };

    LearningAgentFactory sarsaLearningFactory = new LearningAgentFactory() {
        @Override
        public String getAgentName() {
            return "SARSA";
        }

        @Override
        public LearningAgent generateAgent() {
            return new SarsaLam(domain, 0.99, hashingFactory, 0.0, 1.0, 1.);
        }
    };

    LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter(env, 10, 100, qLearningFactory);
    exp.setUpPlottingConfiguration(500, 250, 2, 1000,
            TrialMode.MOSTRECENTANDAVERAGE,
            PerformanceMetric.CUMULATIVESTEPSPEREPISODE,
            PerformanceMetric.AVERAGEEPISODEREWARD,
            PerformanceMetric.STEPSPEREPISODE,
            PerformanceMetric.CUMULATIVEREWARDPERSTEP);

    exp.startExperiment();
    exp.writeStepAndEpisodeDataToCSV("expData");

}

 public static int[][] flip_horizontal(int[][] data) {
    //for each row, swap its contents from left to right
    for (int row = 0; row < data.length; row++) {
        for (int col = 0; col < data[0].length / 2; col++) {
            // given a column: i, its pair is column: width() - i - 1
            // e.g. with a width of 10
            // column 0 is paired with column 9
            // column 1 is paired with column 8 etc.
            int temp = data[row][col];
            data[row][col] = data[row][data[0].length - col - 1];
            data[row][data[0].length - col - 1] = temp;
        }
    }
    return data;

}
public static int[][] getMap(String path, int x, int y){
    try{
        Scanner scanner = new Scanner(new File(path));
        int[][] tall = new int[x][y];
        int i = 0;
        int j = -1;
        while(scanner.hasNextInt()){
            //start a new column every x values read
            if (i % x == 0){
                j++;
            }
            tall[i % x][j % y] = scanner.nextInt();
            i++;
        }
        tall = flip_horizontal(tall);
        return tall;
    }
    catch (FileNotFoundException e){
        System.out.println("File Not Found");
    }
    catch (Exception e){
        System.out.println("Other Exception " + e.toString());
    }
    return new int[x][y];
}
public static void main(String[] args) {

    int[][] map = getMap("./src/main/java/maze_larger.txt", 15, 16);

    AdvancedBehavior example = new AdvancedBehavior(map, .1, -.04);
    String outputPath = "output/";

    example.experimentAndPlotter();
    example.visualize(outputPath);
}

}
jmacglashan commented 8 years ago

It works for both, because Q-learning operates by interacting with an Environment instance and VisualExplorer is also interacting with an Environment (it constructs one for you if you don't specify it). That is, in VisualExplorer, you select an action, that action is passed to the Environment for execution, and the Environment tells VisualExplorer what happens (which is then visualized). In Q-learning, the agent selects an action, passes it to the Environment for execution, and the Environment tells Q-learning what happened (which it uses for learning).

This can perhaps be better expressed in the code I gave you by creating the SimulatedEnvironment, sampling from it directly, and telling the VisualExplorer to use it:

public static void main(String[] args) {

        GridWorldDomain gwd = new GridWorldDomain(11, 11);
        gwd.setMapToFourRooms();
        gwd.setProbSucceedTransitionDynamics(0.1); //add very high stochasticity
        Domain domain = gwd.generateDomain();
        State s = GridWorldDomain.getOneAgentNoLocationState(domain, 0, 0);
        Environment env = new SimulatedEnvironment(domain, new NullRewardFunction(), new NullTermination(), s);

        //get exact transition probabilities for east
        GroundedAction a = domain.getAction("east").getAssociatedGroundedAction();
        List<TransitionProbability> tps = a.getTransitions(s);
        for(TransitionProbability tp : tps){
            System.out.println(direction(s, tp.s) + ": " + tp.p);
        }

        System.out.println("--");

        //get sample based transition probabilities from executing in our environment like Q-learning will
        HashedAggregator<String> ag = new HashedAggregator<String>();
        for(int i = 0; i < 10000; i++){
            ag.add(direction(s, a.executeIn(env).op), 1.);
            //reset environment back to our corner state for more sampling
            env.resetEnvironment();
        }
        for(Map.Entry<String, Double> e : ag.entrySet()){
            System.out.println(e.getKey() + ": " + e.getValue() / 10000);
        }

        //visual explorer using our environment instance
        VisualExplorer exp = new VisualExplorer(domain, env, GridWorldVisualizer.getVisualizer(gwd.getMap()));
        exp.addKeyAction("w", "north");
        exp.addKeyAction("s", "south");
        exp.addKeyAction("d", "east");
        exp.addKeyAction("a", "west");

        exp.initGUI();
    }

Since I don't have access to your grid files, here is some simple example Q-learning code with a success probability of 0.1, where you should indeed see in the episodes that the agent usually doesn't move in the selected direction (it does in mine).

public static void main(String[] args) {

        GridWorldDomain gwd = new GridWorldDomain(11, 11);
        gwd.setMapToFourRooms();
        gwd.setProbSucceedTransitionDynamics(0.1);
        Domain domain = gwd.generateDomain();
        RewardFunction rf = new UniformCostRF();
        TerminalFunction tf = new GridWorldTerminalFunction(10, 10);
        State s = GridWorldDomain.getOneAgentNoLocationState(domain, 0, 0);
        SimulatedEnvironment env = new SimulatedEnvironment(domain, rf, tf, s);

        QLearning ql = new QLearning(domain, 0.99, new SimpleHashableStateFactory(), 0., 0.01);

        //do one episode of learning for demonstration
        EpisodeAnalysis ea = ql.runLearningEpisode(env);

        new EpisodeSequenceVisualizer(GridWorldVisualizer.getVisualizer(gwd.getMap()), domain, Arrays.asList(ea));

    }