
A3C Agent fails with MeanStdFilter #3814

Closed: dmadeka closed this issue 4 years ago

dmadeka commented 5 years ago

When I try to run A3C with continuous actions and a MeanStdFilter observation filter, I get the error below.

This is surprising because I'm not using the ConcurrentMeanStdFilter. Does A3C not support MeanStdFilter?

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-37-dca1e47835d0> in <module>()
     10 # config["num_workers"] = 1
     11 # config["model"] = {"custom_model": "my_model", "custom_preprocessor": "my_prep", "custom_options": {}}
---> 12 agent = a3c.A3CAgent(env='SCOTRL', config=config)

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/agents/agent.py in __init__(self, config, env, logger_creator)
    246             logger_creator = default_logger_creator
    247 
--> 248         Trainable.__init__(self, config, logger_creator)
    249 
    250     @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/tune/trainable.py in __init__(self, config, logger_creator)
     86         self._iterations_since_restore = 0
     87         self._restored = False
---> 88         self._setup(copy.deepcopy(self.config))
     89         self._local_ip = ray.services.get_node_ip_address()
     90 

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/agents/agent.py in _setup(self, config)
    316         # TODO(ekl) setting the graph is unnecessary for PyTorch agents
    317         with tf.Graph().as_default():
--> 318             self._init()
    319 
    320     @override(Trainable)

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/agents/a3c/a3c.py in _init(self)
     56 
     57         self.local_evaluator = self.make_local_evaluator(
---> 58             self.env_creator, policy_cls)
     59         self.remote_evaluators = self.make_remote_evaluators(
     60             self.env_creator, policy_cls, self.config["num_workers"])

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/agents/agent.py in make_local_evaluator(self, env_creator, policy_graph)
    436             merge_dicts(self.config, {
    437                 "tf_session_args": self.
--> 438                 config["local_evaluator_tf_session_args"]
    439             }))
    440 

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/agents/agent.py in _make_evaluator(self, cls, env_creator, policy_graph, worker_index, config)
    576             input_creator=input_creator,
    577             input_evaluation_method=config["input_evaluation"],
--> 578             output_creator=output_creator)
    579 
    580     def __getstate__(self):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/evaluation/policy_evaluator.py in __init__(self, env_creator, policy_graph, policy_mapping_fn, policies_to_train, tf_session_creator, batch_steps, batch_mode, episode_horizon, preprocessor_pref, sample_async, compress_observations, num_envs, observation_filter, clip_rewards, clip_actions, env_config, model_config, policy_config, worker_index, monitor_path, log_dir, log_level, callbacks, input_creator, input_evaluation_method, output_creator)
    328                 tf_sess=self.tf_sess,
    329                 clip_actions=clip_actions,
--> 330                 blackhole_outputs=input_evaluation_method == "simulation")
    331             self.sampler.start()
    332         else:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py in __init__(self, env, policies, policy_mapping_fn, preprocessors, obs_filters, clip_rewards, unroll_length, callbacks, horizon, pack, tf_sess, clip_actions, blackhole_outputs)
    104         for _, f in obs_filters.items():
    105             assert getattr(f, "is_concurrent", False), \
--> 106                 "Observation Filter must support concurrent updates."
    107         self.async_vector_env = AsyncVectorEnv.wrap_async(env)
    108         threading.Thread.__init__(self)

AssertionError: Observation Filter must support concurrent updates.
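
A minimal sketch of what the assertion checks, assuming the Ray 0.6-era ray.rllib.utils.filter layout (the observation shape below is hypothetical):

from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter

# The async sampler (used because A3C sets sample_async=True) requires every
# observation filter to advertise is_concurrent; a plain MeanStdFilter does
# not, which is what trips the assertion above.
shape = (4,)  # hypothetical observation shape
print(getattr(MeanStdFilter(shape), "is_concurrent", False))            # False -> AssertionError
print(getattr(ConcurrentMeanStdFilter(shape), "is_concurrent", False))  # True  -> passes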
dmadeka commented 5 years ago

Config File:

from ray.rllib.agents.agent import with_common_config  # location in Ray 0.6.x

config = with_common_config({
    # Size of batches collected from each worker
    "sample_batch_size": 200,
    "num_gpus": 16,
    "num_workers": 50,
    # Number of timesteps collected for each SGD round
    "train_batch_size": 4000,
    # Learning rate
    "lr": 5e-5,
    "batch_mode": "truncate_episodes",
    # Which observation filter to apply to the observation
    "observation_filter": "MeanStdFilter",
    # Uses the sync samples optimizer instead of the multi-gpu one. This does
    # not support minibatches.
    # "simple_optimizer": True,
    # Workers sample async. Note that this increases the effective
    # sample_batch_size by up to 5x due to async buffering of batches.
    "sample_async": True,
})

Function call:

agent = a3c.A3CAgent(env='MyEnv', config=config)
ericl commented 5 years ago

I think it will work if you use "observation_filter": "ConcurrentMeanStdFilter"?

We should probably choose that automatically when sample_async is True.
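
For reference, that would be a one-line change to the config posted above (reusing the env name from the report):

config["observation_filter"] = "ConcurrentMeanStdFilter"
agent = a3c.A3CAgent(env='MyEnv', config=config)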

dmadeka commented 5 years ago

Got it, there are separate ConcurrentMeanStdFilter and MeanStdFilter classes. Thanks so much!

Not sure why the split, by the way? Wouldn't you call the appropriate one depending on the algorithm and sample_async?
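
In the spirit of that suggestion, the automatic choice could look something like this hypothetical sketch (pick_observation_filter is not an RLlib API):

def pick_observation_filter(config):
    # Hypothetical helper: upgrade to the concurrent filter whenever the
    # async sampler will be used; otherwise keep the configured filter.
    if config.get("sample_async") and config.get("observation_filter") == "MeanStdFilter":
        return "ConcurrentMeanStdFilter"
    return config.get("observation_filter", "NoFilter")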

richardliaw commented 4 years ago

Looks like this is resolved.