rllab is a framework for developing and evaluating reinforcement learning algorithms, fully compatible with OpenAI Gym.

Plot=True not working for the Tensorflow version. #114

Open rksltnl opened 7 years ago

rksltnl commented 7 years ago

Hi,

I'm using the TensorFlow version of rllab in the sandbox. When I set plot=True, the code runs but it doesn't draw the cartpole plot for me (unlike the Theano version, which I checked does draw the plot).

Here's the code I ran:

# Imports assumed from the standard sandbox TF launcher this was copied from
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
from rllab.plotter import plotter
from sandbox.rocky.tf.algos.trpo import TRPO
from sandbox.rocky.tf.envs.base import TfEnv
from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy

env = TfEnv(normalize(CartpoleEnv()))
policy = GaussianMLPPolicy(
    name="policy",
    env_spec=env.spec,
    # The neural network policy should have two hidden layers, each with 32 hidden units.
    hidden_sizes=(32, 32)
)
baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(
    env=env,
    policy=policy,
    baseline=baseline,
    batch_size=4000,
    max_path_length=100,
    n_itr=40,
    discount=0.99,
    step_size=0.01,
    plot=True,
)
plotter.init_worker()
algo.train()

Here's the error message it prints out. The code still runs regardless; it just doesn't show me the cartpole plot.

2017-04-14 21:38:53.613888 PDT | action space: Box(1,)
2017-04-14 21:38:53.614460 PDT | itr #0 | Obtaining samples...
2017-04-14 21:38:53.614543 PDT | itr #0 | Obtaining samples for iteration 0...
Traceback (most recent call last):
  File ".../python_runtime/v2_7/python27.zip.tmp/multiprocessing/queues.py", line 268, in _feed
    send(obj)
  File ".../rllab/sandbox/rocky/tf/core/parameterized.py", line 91, in __getstate__
    d["params"] = self.get_param_values()
  File ".../rllab/sandbox/rocky/tf/core/parameterized.py", line 60, in get_param_values
    param_values = tf.get_default_session().run(params)
AttributeError: 'NoneType' object has no attribute 'run'
2017-04-14 21:38:53.893365 PDT | itr #0 | Processing samples...

Do you have any idea how to fix this?

Thank you.

dementrock commented 7 years ago

The issue is that a TensorFlow session needs to be created in the plotter process. I don't have cycles to fix this myself, but it shouldn't be too hard to get it working.
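(For context: in TF1 the default session is local to the process, and even to the thread, that entered it, so neither a separate plotter process nor multiprocessing's queue feeder thread can see the session created in the main thread. A minimal sketch illustrating this, independent of rllab:)

import threading

import tensorflow as tf

# TF1's default-session stack is thread-local: the session entered below
# is invisible to other threads (and to child processes), which is why
# tf.get_default_session() returns None in the traceback above -- the
# policy is pickled on multiprocessing's queue feeder thread.
with tf.Session() as sess:
    print(tf.get_default_session() is sess)   # True in this thread

    def check():
        print(tf.get_default_session())       # None in another thread

    t = threading.Thread(target=check)
    t.start()
    t.join()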

flyingdogs commented 7 years ago

Is there any further instruction on how to fix this bug?

dementrock commented 7 years ago

Basically, in TF algorithms the worker processes first run a method to initialize tf session: https://github.com/openai/rllab/blob/d5f09d9b7d7f651678ce8cfc8f661e11ae7dfaf7/sandbox/rocky/tf/samplers/batch_sampler.py#L7

Similar things should happen in the plotting process: https://github.com/openai/rllab/blob/d5f09d9b7d7f651678ce8cfc8f661e11ae7dfaf7/sandbox/rocky/tf/samplers/batch_sampler.py#L7

I suspect wrapping the while loop in a with tf.Session() block would work.
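A minimal sketch of that suggestion, assuming the queue-driven worker loop from rllab's plotter (the names here are illustrative, not the exact file contents):

import multiprocessing

import tensorflow as tf

queue = multiprocessing.Queue()

def _worker_start():
    # Entering a session here makes it the default session of the plotter
    # process, so code run on deserialized policies (e.g. get_param_values)
    # can call tf.get_default_session() successfully.
    with tf.Session():
        env, policy = None, None
        while True:
            msg = queue.get()
            if msg[0] == 'stop':
                break
            elif msg[0] == 'update':
                env, policy = msg[1], msg[2]
                # ... render a rollout with env/policy here ...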

rksltnl commented 7 years ago

You meant the while loop in here, right?

I tried adding with tf.Session() as sess: but it didn't change the behavior.

dementrock commented 7 years ago

Yes, that's correct. Where did you add this statement? Can you show me the code, here or in a gist file?

rksltnl commented 7 years ago

Hi,

This is my main file (copied and pasted from your launcher): https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-main-py

I added these two lines to create a TF session in the plotter. https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-plotter-py-L26-L27

I added one line to batch_polopt.py to call plotter.init_worker() in start_worker(): https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-batch_polopt-py-L92

dementrock commented 7 years ago

I see. You shouldn't call plotter.init_worker() before algo.train(), because the TF session hasn't been initialized yet at that point.

rksltnl commented 7 years ago

plotter.init_worker() is inside the start_worker() function, which gets called inside algo.train(). So plotter.init_worker() is called after algo.train() is called and the TF session is initialized, right?

https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-batch_polopt-py-L89-L107
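(For reference, a sketch of the call order being described, mirroring the gist's batch_polopt.py; the exact surrounding code is assumed:)

import tensorflow as tf
from rllab.plotter import plotter

class BatchPolopt:
    # structural sketch only, following the gist
    def start_worker(self):
        plotter.init_worker()    # called from inside train()'s session

    def train(self):
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            self.start_worker()  # so the session already exists here
            # ... training iterations ...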

dementrock commented 7 years ago

Sorry, I was looking at the earlier code snippet you had: https://github.com/openai/rllab/issues/114#issue-221927819. Is this corrected in the newer version? If so, it's really weird. Do you have a longer stack trace?

rksltnl commented 7 years ago

Yeah, the corrected (gist) version doesn't change the behavior at all. I've pasted the console output; it doesn't show any more information. The error is ignored and the run continues. https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-console-output-L167-L176

dementrock commented 7 years ago

So I tried digging into the issue myself. The problem is that the policy gets deserialized before the session is created in the worker process. A way to fix this is to send a pickled version of the policy and unpickle it manually on the other end. So instead of queue.put(['update', env, policy]) at https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-plotter-py-L69, you do queue.put(['update', env, pickle.dumps(policy)]), and on the other end you do pickle.loads(policy_pkl).
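A minimal sketch of that change, assuming the gist's queue-based plotter interface (the function names are illustrative):

import pickle
from multiprocessing import Queue

import tensorflow as tf

queue = Queue()

def update_plot(env, policy):
    # Sending side: pickle the policy explicitly in the caller's thread,
    # where a default TF session exists, instead of letting the Queue's
    # feeder thread pickle it implicitly (where no default session is set).
    queue.put(['update', env, pickle.dumps(policy)])

def worker_loop():
    # Receiving side: create the session before unpickling, so the
    # policy's deserialization can find a default session.
    with tf.Session():
        env, policy = None, None
        while True:
            msg = queue.get()
            if msg[0] == 'update':
                env, policy_pkl = msg[1], msg[2]
                policy = pickle.loads(policy_pkl)
            elif msg[0] == 'stop':
                break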

You also need to make sure plotter.init_worker is called before the policy object is created. This is usually done by putting that statement at the top of the file, or by using run_experiment_lite. However, I then ran into what looks like a race condition between TensorFlow and the rendering thread, which I could not resolve.

I recommend that instead of all this, you just do a synchronized render after each iteration. So at https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-batch_polopt-py-L129, rather than calling self.update_plot(), use rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length). Also remove the lines that initialize the worker process at https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-batch_polopt-py-L91.
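A minimal sketch of that workaround; rollout is rllab's stock helper from rllab.sampler.utils, and plot_after_iteration is just an illustrative name for the loop-body change, not a real method:

from rllab.sampler.utils import rollout

def plot_after_iteration(algo):
    # Equivalent of the suggested change inside BatchPolopt.train():
    # a blocking, synchronous render of one episode in the main process,
    # where the TF session already exists, instead of algo.update_plot().
    if algo.plot:
        rollout(algo.env, algo.policy, animated=True,
                max_path_length=algo.max_path_length)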

Finally, can you submit a pull request for this once you get it working?

rksltnl commented 7 years ago

Yes, I'll submit the PR for the solution with episode-blocking plots tomorrow.

rksltnl commented 7 years ago

I'm looking into the more fundamental fix as well. When you said "on the other end", did you mean the update case in here? https://gist.github.com/rksltnl/3d082c78995808ddf9649c6be344186c#file-plotter-py-L40

And use it like this?

elif 'update' in msgs:
    env, policy_pkl = msgs['update']
    policy = pickle.loads(policy_pkl)

dementrock commented 7 years ago

Yes.