Closed: ethanluoyc closed this issue 2 years ago.
Hi, the only scenario that comes to my mind is that Learner's step samples more than once from the iterator. Is that the case?
@qstanczyk I don't think I sampled more than once, though. The deadlock also does not happen at a fixed time. Sometimes it happens sooner and sometimes later.
Can you check exactly where the agent is getting stuck? Is it on the next() call on the iterator that LocalLayout controls? If so, can you debug what happens in the _has_data_for_training method? It must be returning true while it looks like the iterator doesn't have more data to sample.
@qstanczyk
OK, I think I hit a case where the iterator is not ready but there is still data available for sampling. Not sure why this can happen, though.
def _has_data_for_training(self):
if self._iterator.ready():
return True
for (table, batch_size) in zip(self._replay_tables,
self._batch_size_upper_bounds):
if not table.can_sample(batch_size):
return False
return True
def update(self):
# super().update()
if self._iterator:
# Perform learner steps as long as iterator has data.
update_actor = False
while self._has_data_for_training():
# Run learner steps (usually means gradient steps).
iterator_ready = self._iterator.ready()
table_can_sample = True
t = self._replay_tables[0]
bsz_ub = self._batch_size_upper_bounds[0]
table_can_sample = t.can_sample(bsz_ub)
print(iterator_ready, table_can_sample, bsz_ub)
# on blocking gets [False True 261], I am using a batch size of 256
self._learner_steps += 1
self._batch_size_upper_bounds = [
math.ceil(t.info.rate_limiter_info.sample_stats.completed /
self._learner_steps) for t in self._replay_tables
]
self._learner.step()
update_actor = True
if update_actor:
# Update the actor weights only when learner was updated.
self._actor.update()
return
This is expected (and that is why _has_data_for_training checks both conditions). It can happen when the iterator hasn't fetched the data from the table yet, but the Reverb table has data to be sampled. In that case the call to next() will block for some time until the data is fetched from the table (but it should not hang).
But it's hanging in my case somehow (or maybe just blocking for a long time). Are there some Reverb stats that I can look into to understand what's going on?
Small change to verify that the iterator is accessed correctly. Can you try it and see if the assertion is hit?
I just tested. I do not hit the assert, but the deadlock still appears.
Let me grab some statistics from the rate limiter to see what's going on.
I added the following debugging code to update:
def update(self):
if self._iterator:
# Perform learner steps as long as iterator has data.
update_actor = False
while self._has_data_for_training():
## Added for DEBUG
while not self._iterator.ready():
t = self._replay_tables[0]
num_inserts = t.info.rate_limiter_info.insert_stats.completed
num_samples = t.info.rate_limiter_info.sample_stats.completed
samples_per_insert = t.info.rate_limiter_info.samples_per_insert
# min_size_to_sample = t.info.rate_limiter_info.min_size_to_sample
print(
t.info.rate_limiter_info.min_diff,
num_inserts * samples_per_insert - num_samples,
t.info.rate_limiter_info.max_diff,
t.can_sample(self._batch_size_upper_bounds[0]),
)
import time
time.sleep(0.01)
## END DEBUG
# Run learner steps (usually means gradient steps).
batches_processed = self._iterator.retrieved_elements()
self._learner.step()
assert self._iterator.retrieved_elements() == batches_processed + 1, (
'Learner step must retrieve exactly one '
'element from the iterator. Otherwise agent can deadlock.')
self._batch_size_upper_bounds = [
math.ceil(t.info.rate_limiter_info.sample_stats.completed /
(batches_processed + 1)) for t in self._replay_tables
]
update_actor = True
if update_actor:
# Update the actor weights only when learner was updated.
self._actor.update()
return
At blocking time I get:
230400.0 230400.0 1.7976931348623157e+308 False
I did some calculation by hand and the numbers seem to match my expectation.
I use an SPI of 128.0 and an error tolerance of 0.1. The min_replay_size is 2000. With that, the min_diff should be 2000 * (1 - 0.1) * 128 = 230400. The max_diff is changed in the local layout, so that seems alright. n_inserts * spi - n_samples is equal to min_diff in this case, so based on https://github.com/deepmind/reverb/blob/master/reverb/cc/rate_limiter.cc#L112 I would actually expect that Reverb can be sampled, yet the iterator never gets unblocked.
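As a sanity check, here is a small sketch of where those numbers come from. It assumes Reverb's SampleToInsertRatio limiter sets min_diff = samples_per_insert * min_size_to_sample - error_buffer (which matches the observed 230400.0) and uses the error_buffer formula from the builder shown further down:

# Hypothetical back-of-the-envelope check of the rate limiter bounds quoted above.
samples_per_insert = 128.0
samples_per_insert_tolerance_rate = 0.1
min_replay_size = 2000

# error_buffer as computed in make_replay_tables below.
error_buffer = min_replay_size * (samples_per_insert_tolerance_rate * samples_per_insert)
# Lower bound on (inserts * spi - samples) below which sampling is blocked.
min_diff = min_replay_size * samples_per_insert - error_buffer
# The local layout replaces max_diff with float max so inserts are never blocked.
print(error_buffer, min_diff)  # 25600.0 230400.0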
How is the iterator constructed? Do you use multiple workers per iterator?
It's basically the same as all other JAX agents. The agent uses a single worker, same as e.g. JAX D4PG, and I only use a single GPU for training.
Here are parts that are probably relevant.
"""DrQ-v2 builder"""
from typing import Callable, Iterator, List, Optional
from acme import adders
from acme import core
from acme import datasets
from acme import specs
from acme.adders import reverb as adders_reverb
from acme.agents.jax import builders
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.jax import variable_utils
from acme.utils import counting
from acme.utils import loggers
import jax
import optax
import reverb
from reverb import rate_limiters
from ilax.agents.drq_v2 import acting as acting_lib
from ilax.agents.drq_v2 import config as drq_v2_config
from ilax.agents.drq_v2 import learning as learning_lib
from ilax.agents.drq_v2 import networks as drq_v2_networks
class DrQV2Builder(builders.ActorLearnerBuilder):
"""DrQ-v2 Builder."""
def __init__(
self,
config: drq_v2_config.DrQV2Config,
):
self._config = config
def make_replay_tables(
self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
"""Create tables to insert data into."""
samples_per_insert_tolerance = (
self._config.samples_per_insert_tolerance_rate *
self._config.samples_per_insert)
error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
limiter = rate_limiters.SampleToInsertRatio(
min_size_to_sample=self._config.min_replay_size,
samples_per_insert=self._config.samples_per_insert,
error_buffer=error_buffer,
)
replay_table = reverb.Table(
name=self._config.replay_table_name,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
max_size=self._config.max_replay_size,
rate_limiter=limiter,
signature=adders_reverb.NStepTransitionAdder.signature(
environment_spec=environment_spec),
)
return [replay_table]
def make_dataset_iterator(
self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
"""Create a dataset iterator to use for learning/updating the agent."""
dataset = datasets.make_reverb_dataset(
table=self._config.replay_table_name,
server_address=replay_client.server_address,
batch_size=self._config.batch_size,
prefetch_size=self._config.prefetch_size,
)
return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])
def make_adder(self, replay_client: reverb.Client) -> Optional[adders.Adder]:
"""Create an adder which records data generated by the actor/environment.
Args:
replay_client: Reverb Client which points to the replay server.
"""
return adders_reverb.NStepTransitionAdder(
client=replay_client,
n_step=self._config.n_step,
discount=self._config.discount,
)
def make_actor(
self,
random_key: networks_lib.PRNGKey,
policy_network: drq_v2_networks.DrQV2PolicyNetwork,
adder: Optional[adders.Adder] = None,
variable_source: Optional[core.VariableSource] = None) -> core.Actor:
"""Create an actor instance.
Args:
random_key: A key for random number generation.
policy_network: Instance of a policy network; this should be a callable
which takes as input observations and returns actions.
adder: How data is recorded (e.g. added to replay).
variable_source: A source providing the necessary actor parameters.
"""
assert variable_source is not None
device = "cpu"
variable_client = variable_utils.VariableClient(
variable_source, "policy", device=device)
variable_client.update_and_wait()
return acting_lib.DrQV2Actor(
policy_network,
random_key,
variable_client=variable_client,
adder=adder,
backend=device,
)
def make_learner(self,
random_key: networks_lib.PRNGKey,
networks: drq_v2_networks.DrQV2Networks,
dataset: Iterator[reverb.ReplaySample],
logger: Optional[loggers.Logger] = None,
replay_client: Optional[reverb.Client] = None,
counter: Optional[counting.Counter] = None) -> core.Learner:
"""Creates an instance of the learner.
Args:
random_key: A key for random number generation.
networks: struct describing the networks needed by the learner; this can
be specific to the learner in question.
dataset: iterator over samples from replay.
replay_client: client which allows communication with replay, e.g. in
order to update priorities.
counter: a Counter which allows for recording of counts (learner steps,
actor steps, etc.) distributed throughout the agent.
checkpoint: bool controlling whether the learner checkpoints itself.
"""
del replay_client
config = self._config
critic_optimizer = optax.adam(config.learning_rate)
policy_optimizer = optax.adam(config.learning_rate)
encoder_optimizer = optax.adam(config.learning_rate)
sigma_start, sigma_end, sigma_schedule_steps = config.sigma
observations_per_step = int(config.batch_size / config.samples_per_insert)
if hasattr(config, "min_observations"):
min_observations = config.min_observations
else:
min_observations = config.min_replay_size
# Compute the schedule for the learner
# Learner only starts updating after min_observations number of steps
sigma_schedule = lambda step: optax.linear_schedule( # noqa
sigma_start, sigma_end, sigma_schedule_steps)((step + max(
min_observations, config.batch_size)) * observations_per_step)
return learning_lib.DrQV2Learner(
random_key=random_key,
dataset=dataset,
networks=networks,
sigma_schedule=sigma_schedule,
policy_optimizer=policy_optimizer,
critic_optimizer=critic_optimizer,
encoder_optimizer=encoder_optimizer,
augmentation=config.augmentation,
critic_soft_update_rate=config.critic_q_soft_update_rate,
discount=config.discount,
noise_clip=config.noise_clip,
logger=logger,
counter=counter,
)
The agent definition, based on LocalLayout:
from typing import Optional
from acme import specs
from acme.jax.layouts import local_layout
from acme.utils import counting
from acme.utils import loggers
import optax
from ilax.agents.drq_v2 import builder
from ilax.agents.drq_v2 import config as drq_v2_config
# from ilax.agents.drq_v2 import local_layout
from ilax.agents.drq_v2 import networks as drq_v2_networks
class DrQV2(local_layout.LocalLayout):
"""Data-regularized Q agent version 2."""
builder: builder.DrQV2Builder
def __init__(
self,
environment_spec: specs.EnvironmentSpec,
networks: drq_v2_networks.DrQV2Networks,
config: drq_v2_config.DrQV2Config,
seed: int,
counter: Optional[counting.Counter] = None,
logger: Optional[loggers.Logger] = None,
):
drq_v2_builder = builder.DrQV2Builder(config)
policy_network = drq_v2_networks.get_default_behavior_policy(
networks, environment_spec.actions,
optax.linear_schedule(*config.sigma))
self.builder = drq_v2_builder
super().__init__(
seed=seed,
environment_spec=environment_spec,
builder=drq_v2_builder,
networks=networks,
policy_network=policy_network,
# min_replay_size=min_replay_size,
batch_size=config.batch_size,
workdir=None,
num_sgd_steps_per_step=1,
learner_logger=logger,
counter=counter,
checkpoint=False)
The config:
import dataclasses
from typing import Tuple

from acme.adders import reverb as adders_reverb

# Assumed import: DataAugmentation and batched_random_crop live in the package's
# augmentations module.
from ilax.agents.drq_v2 import augmentations

@dataclasses.dataclass
class DrQV2Config:
"""Configuration parameters for DrQ."""
augmentation: augmentations.DataAugmentation = augmentations.batched_random_crop
min_replay_size: int = 2_000
max_replay_size: int = 1_000_000
replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
prefetch_size: int = 1
discount: float = 0.99
batch_size: int = 256
n_step: int = 3
critic_q_soft_update_rate: float = 0.01
learning_rate: float = 1e-4
noise_clip: float = 0.3
sigma: Tuple[float, float, int] = (1.0, 0.1, 500000)
samples_per_insert: float = 128
samples_per_insert_tolerance_rate: float = 0.1
Then I don't know ;-(
Me neither. Thanks a lot for checking this with me. It's really helpful! I have a feeling that there is some problem with the prefetch thread not being scheduled, so although there are samples, the prefetching thread is not putting things onto the queue. I don't see how that could possibly happen, and there's no way we can control it given scheduling is done by the OS.
@qstanczyk It seems that the problem may have gone away when I changed the default transition adder's max_in_flight_items from 5 to 2. I am not sure why this helps, though. I am working with dm_control but apply an action repeat of 2 (which means that the episode length is now 500 instead of 1000), and I use n_step = 3. Maybe this somehow interferes with the rate limiter's tolerance. I never fully understood some of Reverb's parameters (like max_in_flight_items). Maybe it's my combination of setup + single-process execution that is causing the issues.
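For reference, a sketch of the change described above, assuming NStepTransitionAdder accepts a max_in_flight_items argument with a default of 5 (as the numbers above suggest):

# Hypothetical variant of the builder's make_adder with a smaller in-flight budget.
# max_in_flight_items bounds how many items the adder may have pending
# (sent but not yet confirmed by the Reverb server) at any time.
def make_adder(self, replay_client):
  return adders_reverb.NStepTransitionAdder(
      client=replay_client,
      n_step=self._config.n_step,
      discount=self._config.discount,
      max_in_flight_items=2,  # lowered from the assumed default of 5
  )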
A smaller value of max_in_flight_items reduces the number of elements that can be inserted into Reverb in one actor step. But as the rate limiter of a single-process agent has blocking of inserts disabled, this setting shouldn't affect hangs.
@qstanczyk does the insertion happen in the thread where the actor interacts with the environment, or is it done in a background thread? In the case where writes happen in the agent's thread, it's possible to deadlock if not all items have been written to Reverb, right?
I still get the deadlock (from htop all threads are suspended), but I was unable to reproduce this every time. Sometimes it hangs after 10K steps, sometimes it does fine even after 1M steps. Is it possible that the problem is caused by the fact that I am using pixel observations (84 x 84 x 9 images)?
I have checked over and over again, and to me it does seem that the prefetching is handled correctly; I don't see anything that could go wrong. I wonder if in my case there is some thread starvation going on that's causing the problem.
Inserts happen in the background, but at the end of the episode there is a flush (which makes sure all pending items are written to Reverb before continuing). However, with a non-distributed agent, Reverb's rate limiter should not block inserts.
Thanks for the clarification! Sorry for keeping the thread open, as I would really like to figure out the problem. If the flush only happens at the end of an episode, does that mean that the agent may deadlock in the middle of an episode if not all writes have been processed by Reverb?
That shouldn't happen either - writers don't block (due to the rate limiter setup), while sampling should happen only when there is data in the iterator. That is the theory... but it seems like there must be an issue somewhere.
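To make the "inserts never block" part concrete, here is a sketch of a limiter with insert blocking effectively disabled, assuming Reverb's base rate_limiters.RateLimiter takes explicit min_diff/max_diff bounds; the 1.7976931348623157e+308 printed earlier is sys.float_info.max, which suggests the single-process setup does roughly this:

import sys
from reverb import rate_limiters

# Hypothetical limiter: same sampling threshold as the SampleToInsertRatio above,
# but with max_diff pushed to float max so inserts are never throttled.
limiter = rate_limiters.RateLimiter(
    samples_per_insert=128.0,
    min_size_to_sample=2000,
    min_diff=230400.0,            # sampling still blocked until enough inserts
    max_diff=sys.float_info.max,  # inserts are never blocked
)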
Yeah. There’s either some subtle thing that’s not handled or I will blame my operating system scheduler for not scheduling the threads that should be processing the items:)
Is there any additional thing that you would like me to check? I wanted to create a minimal example, but my agent has several components and I haven't been able to consistently find the minimal setup that triggers this. I noticed that table.info contains worker stats. Would that be useful for figuring out what's going wrong? I can produce some logs from those stats if they are useful. Otherwise, I suppose we can keep the issue open and maybe I will have a more minimal deterministic example at some point.
There are some discussions on python-dev which might be related. I don't yet see how similar the issue is, but I'm posting here in case it's relevant: https://bugs.python.org/issue46812. I'm wondering if the behavior of the hangs might be similar to the issue discussed in https://github.com/sqlalchemy/sqlalchemy/issues/7679. Maybe it's worth mentioning that my experiment spawns a lot of threads (>500). I have no idea why there are so many of them, but I am using a machine with 48 CPUs and two 3080 GPUs (only using one for training).
A simple repro I could try would be best. Otherwise it is hard to guess what could be wrong.
@qstanczyk I managed to create a relatively minimal example and have attached it below as a zip archive (deadlock.zip). I also pasted the code here; it consists of a few files.
You can run the example with MUJOCO_GL=egl XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 python run.py. As I mentioned previously, the blocking doesn't happen deterministically, so it's worthwhile trying a few times. I have seen cases where running for >1M steps works ok, so even reproducing the bug is not that trivial...
drq_v2.py: an implementation of DrQ-v2 in JAX. I have consolidated the individual components into a single file.
"""Learner component for DrQV2."""
import dataclasses
from functools import partial
import time
from typing import Iterator, List, NamedTuple, Optional, Callable
from acme import adders
from acme import core
from acme import specs
from acme import types
from acme.adders import reverb as adders_reverb
from acme.jax import networks as networks_lib
from acme.jax import types as jax_types
from acme.jax import utils
from acme.jax import variable_utils
from acme import datasets
from acme.agents.jax import builders
from acme.utils import counting
from acme.utils import loggers
from acme.agents.jax import actor_core
from acme.agents.jax import actors
from reverb import rate_limiters
import jax
import jax.numpy as jnp
import optax
import reverb
import networks as drq_v2_networks
DataAugmentation = Callable[[jax_types.PRNGKey, types.NestedArray],
types.NestedArray]
# From https://github.com/ikostrikov/jax-rl/blob/main/jax_rl/agents/drq/augmentations.py
def random_crop(key: jax_types.PRNGKey, img, padding):
crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
crop_from = jnp.concatenate([crop_from, jnp.zeros((1,), dtype=jnp.int32)])
padded_img = jnp.pad(
img, ((padding, padding), (padding, padding), (0, 0)), mode="edge")
return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)
def batched_random_crop(key, imgs, padding=4):
keys = jax.random.split(key, imgs.shape[0])
return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding)
@dataclasses.dataclass
class DrQV2Config:
"""Configuration parameters for DrQ."""
augmentation: DataAugmentation = batched_random_crop
min_replay_size: int = 2_000
max_replay_size: int = 1_000_000
replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
prefetch_size: int = 4
discount: float = 0.99
batch_size: int = 256
n_step: int = 3
critic_q_soft_update_rate: float = 0.01
learning_rate: float = 1e-4
noise_clip: float = 0.3
sigma: float = 0.2
samples_per_insert: float = 128.0
samples_per_insert_tolerance_rate: float = 0.1
num_sgd_steps_per_step: int = 1
def _soft_update(
target_params: networks_lib.Params,
online_params: networks_lib.Params,
tau: float,
) -> networks_lib.Params:
"""
Update target network using Polyak-Ruppert Averaging.
"""
return jax.tree_multimap(lambda t, s: (1 - tau) * t + tau * s, target_params,
online_params)
class TrainingState(NamedTuple):
"""Holds training state for the DrQ learner."""
policy_params: networks_lib.Params
policy_opt_state: optax.OptState
encoder_params: networks_lib.Params
# There is not target encoder parameters in v2.
encoder_opt_state: optax.OptState
critic_params: networks_lib.Params
critic_target_params: networks_lib.Params
critic_opt_state: optax.OptState
key: jax_types.PRNGKey
steps: int
class DrQV2Learner(core.Learner):
"""Learner for DrQ-v2"""
def __init__(
self,
random_key: jax_types.PRNGKey,
dataset: Iterator[reverb.ReplaySample],
networks: drq_v2_networks.DrQV2Networks,
sigma_schedule: optax.Schedule,
augmentation: DataAugmentation,
policy_optimizer: optax.GradientTransformation,
critic_optimizer: optax.GradientTransformation,
encoder_optimizer: optax.GradientTransformation,
noise_clip: float = 0.3,
critic_soft_update_rate: float = 0.005,
discount: float = 0.99,
num_sgd_steps_per_step: int = 1,
counter: Optional[counting.Counter] = None,
logger: Optional[loggers.Logger] = None,
):
def critic_loss_fn(
critic_params: networks_lib.Params,
encoder_params: networks_lib.Params,
critic_target_params: networks_lib.Params,
policy_params: networks_lib.Params,
transitions: types.Transition,
key: jax_types.PRNGKey,
sigma: jnp.ndarray,
):
next_encoded = networks.encoder_network.apply(
encoder_params, transitions.next_observation)
next_action = networks.policy_network.apply(policy_params, next_encoded)
next_action = networks.add_policy_noise(next_action, key, sigma,
noise_clip)
next_q1, next_q2 = networks.critic_network.apply(critic_target_params,
next_encoded,
next_action)
# Calculate q target values
next_q = jnp.minimum(next_q1, next_q2)
target_q = transitions.reward + transitions.discount * discount * next_q
target_q = jax.lax.stop_gradient(target_q)
# Calculate predicted Q
encoded = networks.encoder_network.apply(encoder_params,
transitions.observation)
q1, q2 = networks.critic_network.apply(critic_params, encoded,
transitions.action)
loss_critic = (jnp.square(target_q - q1) +
jnp.square(target_q - q2)).mean(axis=0)
return loss_critic, {"q1": q1.mean(), "q2": q2.mean()}
def policy_loss_fn(
policy_params: networks_lib.Params,
critic_params: networks_lib.Params,
encoder_params: networks_lib.Params,
observation: types.Transition,
sigma: jnp.ndarray,
key,
):
encoded = networks.encoder_network.apply(encoder_params, observation)
action = networks.policy_network.apply(policy_params, encoded)
action = networks.add_policy_noise(action, key, sigma, noise_clip)
q1, q2 = networks.critic_network.apply(critic_params, encoded, action)
q = jnp.minimum(q1, q2)
policy_loss = -q.mean()
return policy_loss, {}
policy_grad_fn = jax.value_and_grad(policy_loss_fn, has_aux=True)
critic_grad_fn = jax.value_and_grad(
critic_loss_fn, argnums=(0, 1), has_aux=True)
def update_step(
state: TrainingState,
transitions: types.Transition,
):
key_aug1, key_aug2, key_policy, key_critic, key = jax.random.split(
state.key, 5)
sigma = sigma_schedule(state.steps)
# Perform data augmentation on o_tm1 and o_t
observation_aug = augmentation(key_aug1, transitions.observation)
next_observation_aug = augmentation(key_aug2,
transitions.next_observation)
transitions = transitions._replace(
observation=observation_aug,
next_observation=next_observation_aug,
)
# Update critic
(critic_loss, critic_aux), (critic_grad, encoder_grad) = critic_grad_fn(
state.critic_params,
state.encoder_params,
state.critic_target_params,
state.policy_params,
transitions,
key_critic,
sigma,
)
encoder_update, encoder_opt_state = encoder_optimizer.update(
encoder_grad, state.encoder_opt_state)
critic_update, critic_opt_state = critic_optimizer.update(
critic_grad, state.critic_opt_state)
encoder_params = optax.apply_updates(state.encoder_params, encoder_update)
critic_params = optax.apply_updates(state.critic_params, critic_update)
# Update policy
(policy_loss, policy_aux), actor_grad = policy_grad_fn(
state.policy_params,
critic_params,
encoder_params,
observation_aug,
sigma,
key_policy,
)
policy_update, policy_opt_state = policy_optimizer.update(
actor_grad, state.policy_opt_state)
policy_params = optax.apply_updates(state.policy_params, policy_update)
# Update target parameters
polyak_update_fn = partial(_soft_update, tau=critic_soft_update_rate)
critic_target_params = polyak_update_fn(
state.critic_target_params,
critic_params,
)
metrics = {
"policy_loss": policy_loss,
"critic_loss": critic_loss,
"sigma": sigma,
**critic_aux,
**policy_aux,
}
new_state = TrainingState(
policy_params=policy_params,
policy_opt_state=policy_opt_state,
encoder_params=encoder_params,
encoder_opt_state=encoder_opt_state,
critic_params=critic_params,
critic_target_params=critic_target_params,
critic_opt_state=critic_opt_state,
key=key,
steps=state.steps + 1,
)
return new_state, metrics
self._iterator = dataset
self._counter = counter or counting.Counter()
self._logger = logger or loggers.make_default_logger(
label="learner",
save_data=False,
asynchronous=True,
serialize_fn=utils.fetch_devicearray,
)
self._update_step = utils.process_multiple_batches(update_step,
num_sgd_steps_per_step)
self._update_step = jax.jit(self._update_step)
# Initialize training state
def make_initial_state(key: jax_types.PRNGKey):
key_encoder, key_critic, key_policy, key = jax.random.split(key, 4)
encoder_init_params = networks.encoder_network.init(key_encoder)
encoder_init_opt_state = encoder_optimizer.init(encoder_init_params)
critic_init_params = networks.critic_network.init(key_critic)
critic_init_opt_state = critic_optimizer.init(critic_init_params)
policy_init_params = networks.policy_network.init(key_policy)
policy_init_opt_state = policy_optimizer.init(policy_init_params)
return TrainingState(
policy_params=policy_init_params,
policy_opt_state=policy_init_opt_state,
encoder_params=encoder_init_params,
critic_params=critic_init_params,
critic_target_params=critic_init_params,
encoder_opt_state=encoder_init_opt_state,
critic_opt_state=critic_init_opt_state,
key=key,
steps=0,
)
# Create initial state.
self._state = make_initial_state(random_key)
# Do not record timestamps until after the first learning step is done.
# This is to avoid including the time it takes for actors to come online and
# fill the replay buffer.
self._timestamp = None
def step(self):
# Get the next batch from the replay iterator
sample = next(self._iterator)
transitions = types.Transition(*sample.data)
# Perform a single learner step
self._state, metrics = self._update_step(self._state, transitions)
# Compute elapsed time
timestamp = time.time()
elapsed_time = timestamp - self._timestamp if self._timestamp else 0
self._timestamp = timestamp
# Increment counts and record the current time
counts = self._counter.increment(steps=1, walltime=elapsed_time)
# Attempts to write the logs.
self._logger.write({**metrics, **counts})
def get_variables(self, names):
variables = {
"policy": {
"encoder": self._state.encoder_params,
"policy": self._state.policy_params,
},
}
return [variables[name] for name in names]
def save(self) -> TrainingState:
return self._state
def restore(self, state: TrainingState) -> None:
self._state = state
class DrQV2Builder(builders.ActorLearnerBuilder):
"""DrQ-v2 Builder."""
def __init__(self, config: DrQV2Config):
self._config = config
def make_replay_tables(
self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
"""Create tables to insert data into."""
samples_per_insert_tolerance = (
self._config.samples_per_insert_tolerance_rate *
self._config.samples_per_insert)
error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
limiter = rate_limiters.SampleToInsertRatio(
min_size_to_sample=self._config.min_replay_size,
samples_per_insert=self._config.samples_per_insert,
error_buffer=error_buffer,
)
replay_table = reverb.Table(
name=self._config.replay_table_name,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
max_size=self._config.max_replay_size,
rate_limiter=limiter,
signature=adders_reverb.NStepTransitionAdder.signature(
environment_spec=environment_spec),
)
return [replay_table]
def make_dataset_iterator(
self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
"""Create a dataset iterator to use for learning/updating the agent."""
dataset = datasets.make_reverb_dataset(
table=self._config.replay_table_name,
server_address=replay_client.server_address,
batch_size=self._config.batch_size *
self._config.num_sgd_steps_per_step,
prefetch_size=self._config.prefetch_size,
)
iterator = dataset.as_numpy_iterator()
return utils.device_put(iterator, jax.devices()[0])
def make_adder(self, replay_client: reverb.Client) -> Optional[adders.Adder]:
"""Create an adder which records data generated by the actor/environment.
Args:
replay_client: Reverb Client which points to the replay server.
"""
return adders_reverb.NStepTransitionAdder(
client=replay_client,
n_step=self._config.n_step,
discount=self._config.discount,
)
def make_actor(
self,
random_key: networks_lib.PRNGKey,
policy_network: drq_v2_networks.DrQV2PolicyNetwork,
adder: Optional[adders.Adder] = None,
variable_source: Optional[core.VariableSource] = None) -> core.Actor:
"""Create an actor instance.
Args:
random_key: A key for random number generation.
policy_network: Instance of a policy network; this should be a callable
which takes as input observations and returns actions.
adder: How data is recorded (e.g. added to replay).
variable_source: A source providing the necessary actor parameters.
"""
assert variable_source is not None
variable_client = variable_utils.VariableClient(
variable_source, "policy", device='cpu')
variable_client.update_and_wait()
return actors.GenericActor(
actor_core.batched_feed_forward_to_actor_core(policy_network),
random_key=random_key,
variable_client=variable_client,
adder=adder,
backend='cpu')
def make_learner(self,
random_key: networks_lib.PRNGKey,
networks: drq_v2_networks.DrQV2Networks,
dataset: Iterator[reverb.ReplaySample],
logger: Optional[loggers.Logger] = None,
replay_client: Optional[reverb.Client] = None,
counter: Optional[counting.Counter] = None) -> core.Learner:
"""Creates an instance of the learner.
Args:
random_key: A key for random number generation.
networks: struct describing the networks needed by the learner; this can
be specific to the learner in question.
dataset: iterator over samples from replay.
replay_client: client which allows communication with replay, e.g. in
order to update priorities.
counter: a Counter which allows for recording of counts (learner steps,
actor steps, etc.) distributed throughout the agent.
checkpoint: bool controlling whether the learner checkpoints itself.
"""
del replay_client
config = self._config
critic_optimizer = optax.adam(config.learning_rate)
policy_optimizer = optax.adam(config.learning_rate)
encoder_optimizer = optax.adam(config.learning_rate)
return DrQV2Learner(
random_key=random_key,
dataset=dataset,
networks=networks,
sigma_schedule=optax.constant_schedule(config.sigma),
policy_optimizer=policy_optimizer,
critic_optimizer=critic_optimizer,
encoder_optimizer=encoder_optimizer,
augmentation=config.augmentation,
critic_soft_update_rate=config.critic_q_soft_update_rate,
discount=config.discount,
noise_clip=config.noise_clip,
num_sgd_steps_per_step=config.num_sgd_steps_per_step,
logger=logger,
counter=counter,
)
networks.py: includes the networks used by the agent as well as the policy.
"""Network definitions for DrQ-v2."""
import dataclasses
from typing import Callable, Optional, Union
from acme import specs
from acme import types
from acme.agents.jax import actor_core
from acme.jax import networks as networks_lib
from acme.jax import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
# Unlike standard FF-policy, in our DrQ-V2 implementation we use
# scheduled stddev parameters, the pure function for the policy
# thus needs to know the current time step of the actor to calculate
# the current stddev.
_Step = int
DrQV2PolicyNetwork = Callable[
[networks_lib.Params, networks_lib.PRNGKey, types.NestedArray, _Step],
types.NestedArray]
class Encoder(hk.Module):
"""Encoder used by DrQ-v2."""
def __call__(self, x):
# Floatify the image.
x = x.astype(jnp.float32) / 255.0 - 0.5
conv_kwargs = dict(
kernel_shape=3,
output_channels=32,
padding="VALID",
# This follows from the reference implementation, the scale accounts for
# using the ReLU activation.
w_init=hk.initializers.Orthogonal(jnp.sqrt(2.0)),
)
return hk.Sequential([
hk.Conv2D(stride=2, **conv_kwargs),
jax.nn.relu,
hk.Conv2D(stride=1, **conv_kwargs),
jax.nn.relu,
hk.Conv2D(stride=1, **conv_kwargs),
jax.nn.relu,
hk.Conv2D(stride=1, **conv_kwargs),
jax.nn.relu,
hk.Flatten(),
])(
x)
class Actor(hk.Module):
"""Policy network used by DrQ-v2."""
def __init__(
self,
action_size: int,
latent_size: int = 50,
hidden_size: int = 1024,
name: Optional[str] = None,
):
super().__init__(name=name)
self.latent_size = latent_size
self.action_size = action_size
self.hidden_size = hidden_size
w_init = hk.initializers.Orthogonal(1.0)
self._trunk = hk.Sequential([
hk.Linear(self.latent_size, w_init=w_init),
hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
jnp.tanh,
])
self._head = hk.Sequential([
hk.Linear(self.hidden_size, w_init=w_init),
jax.nn.relu,
hk.Linear(self.hidden_size, w_init=w_init),
jax.nn.relu,
hk.Linear(self.action_size, w_init=w_init),
# tanh is used to squash the actions into the canonical space.
jnp.tanh,
])
def compute_features(self, inputs):
return self._trunk(inputs)
def __call__(self, inputs):
# Use orthogonal init
# https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
h = self.compute_features(inputs)
mu = self._head(h)
return mu
class Critic(hk.Module):
"""Single Critic network used by DrQ-v2."""
def __init__(self, hidden_size: int = 1024, name: Optional[str] = None):
super().__init__(name)
self.hidden_size = hidden_size
def __call__(self, observation, action):
inputs = jnp.concatenate([observation, action], axis=-1)
# Use orthogonal init
# https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
q_value = hk.nets.MLP(
output_sizes=(self.hidden_size, self.hidden_size, 1),
w_init=hk.initializers.Orthogonal(1.0),
activate_final=False,
)(inputs).squeeze(-1)
return q_value
class DoubleCritic(hk.Module):
"""Twin critic network used by DrQ-v2.
This is simply two identical Critic module.
"""
def __init__(self, latent_size: int = 50, hidden_size: int = 1024, name=None):
super().__init__(name)
self.hidden_size = hidden_size
self.latent_size = latent_size
self._trunk = hk.Sequential([
hk.Linear(self.latent_size, w_init=hk.initializers.Orthogonal(1.0)),
hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
jnp.tanh,
])
self._critic1 = Critic(self.hidden_size, name="critic1")
self._critic2 = Critic(self.hidden_size, name="critic2")
def compute_features(self, inputs):
return self._trunk(inputs)
def __call__(self, observation, action):
# Use orthogonal init
# https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
# The trunk is shared between the twin critics
h = self.compute_features(observation)
return self._critic1(h, action), self._critic2(h, action)
@dataclasses.dataclass
class DrQV2Networks:
encoder_network: networks_lib.FeedForwardNetwork
policy_network: networks_lib.FeedForwardNetwork
critic_network: networks_lib.FeedForwardNetwork
add_policy_noise: Callable[
[types.NestedArray, networks_lib.PRNGKey, float, float],
types.NestedArray]
def get_default_behavior_policy(
networks: DrQV2Networks,
action_specs: specs.BoundedArray,
sigma,
) -> DrQV2PolicyNetwork:
def behavior_policy(
params: networks_lib.Params,
key: networks_lib.PRNGKey,
observation: types.NestedArray,
):
feature_map = networks.encoder_network.apply(params["encoder"], observation)
action = networks.policy_network.apply(params["policy"], feature_map)
noise = jax.random.normal(key, shape=action.shape) * sigma
noisy_action = jnp.clip(action + noise, action_specs.minimum,
action_specs.maximum)
return noisy_action
return behavior_policy
def make_networks(spec: specs.EnvironmentSpec,
hidden_size: int = 1024,
latent_size: int = 50) -> DrQV2Networks:
"""Create networks for the DrQ-v2 agent."""
action_size = onp.prod(spec.actions.shape, dtype=int)
def add_policy_noise(
action: types.NestedArray,
key: networks_lib.PRNGKey,
sigma: float,
noise_clip: float,
) -> types.NestedArray:
"""Adds action noise to bootstrapped Q-value estimate in critic loss."""
noise = jax.random.normal(key=key, shape=spec.actions.shape) * sigma
noise = jnp.clip(noise, -noise_clip, noise_clip)
return jnp.clip(action + noise, spec.actions.minimum, spec.actions.maximum)
def _critic_fn(x, a):
return DoubleCritic(
latent_size=latent_size,
hidden_size=hidden_size,
)(x, a)
def _policy_fn(x):
return Actor(
action_size=action_size,
latent_size=latent_size,
hidden_size=hidden_size,
)(x)
def _encoder_fn(x):
return Encoder()(x)
policy = hk.without_apply_rng(hk.transform(_policy_fn, apply_rng=True))
critic = hk.without_apply_rng(hk.transform(_critic_fn, apply_rng=True))
encoder = hk.without_apply_rng(hk.transform(_encoder_fn, apply_rng=True))
# policy_feature = hk.without_apply_rng(
# hk.transform(_policy_features_fn, apply_rng=True))
dummy_action = utils.zeros_like(spec.actions)
dummy_obs = utils.zeros_like(spec.observations)
dummy_action = utils.add_batch_dim(dummy_action)
dummy_obs = utils.add_batch_dim(dummy_obs)
dummy_encoded = hk.testing.transform_and_run(
_encoder_fn, seed=0, jax_transform=jax.jit)(
dummy_obs)
return DrQV2Networks(
encoder_network=networks_lib.FeedForwardNetwork(
lambda key: encoder.init(key, dummy_obs), encoder.apply),
policy_network=networks_lib.FeedForwardNetwork(
lambda key: policy.init(key, dummy_encoded), policy.apply),
critic_network=networks_lib.FeedForwardNetwork(
lambda key: critic.init(key, dummy_encoded, dummy_action),
critic.apply),
add_policy_noise=add_policy_noise)
run.py: the training script for running the experiment.
from absl import app
from acme import specs
from acme import wrappers
from acme.jax import experiments
from acme.utils import loggers
from acme.wrappers import mujoco
from dm_control import suite
import dm_env
import drq_v2
import jax
import networks as networks_lib
import tensorflow as tf
def make_experiment_logger(label, steps_key, task_instance=0):
del task_instance
return loggers.make_default_logger(
label, save_data=False, steps_key=steps_key)
def make_environment(domain: str,
task: str,
seed=None,
from_pixels: bool = False,
num_action_repeats: int = 1,
frames_to_stack: int = 0,
camera_id: int = 0) -> dm_env.Environment:
"""Create a dm_control suite environment."""
environment = suite.load(domain, task, task_kwargs={"random": seed})
if from_pixels:
environment = mujoco.MujocoPixelWrapper(environment, camera_id=camera_id)
else:
environment = wrappers.ConcatObservationWrapper(environment)
if num_action_repeats > 1:
environment = wrappers.ActionRepeatWrapper(environment, num_action_repeats)
if frames_to_stack > 0:
assert from_pixels, "frame stack for state not supported"
environment = wrappers.FrameStackingWrapper(
environment, frames_to_stack, flatten=True)
environment = wrappers.CanonicalSpecWrapper(environment, clip=True)
environment = wrappers.SinglePrecisionWrapper(environment)
return environment
def main(_):
tf.config.set_visible_devices([], 'GPU')
environment_factory = lambda seed: make_environment(
domain='cheetah',
task='run',
seed=seed,
from_pixels=True,
num_action_repeats=2,
frames_to_stack=3,
camera_id=0)
num_steps = int(1.5e6)
environment = environment_factory(0)
environment_spec = specs.make_environment_spec(environment)
network_factory = networks_lib.make_networks
drq_config = drq_v2.DrQV2Config()
policy_factory = lambda n: networks_lib.get_default_behavior_policy(
n, environment_spec.actions, drq_config.sigma)
eval_policy_factory = lambda n: networks_lib.get_default_behavior_policy(
n, environment_spec.actions, 0.0)
# Construct the agent.
builder = drq_v2.DrQV2Builder(drq_config)
experiment = experiments.Config(
builder=builder,
network_factory=network_factory,
policy_network_factory=policy_factory,
environment_factory=environment_factory,
eval_policy_network_factory=eval_policy_factory,
environment_spec=environment_spec,
observers=(),
seed=0,
logger_factory=make_experiment_logger,
max_number_of_steps=num_steps)
experiments.run_experiment(
experiment, eval_every=int(1e4), num_eval_episodes=5)
if __name__ == '__main__':
jax.config.config_with_absl()
app.run(main)
The problem is with the default value of num_parallel_calls in make_reverb_dataset. By default it creates 12 workers to fetch data, and each of them does the batching on its own. Hence, even if there are batch_size elements available in Reverb, the elements can be sampled by different workers so that none of the batches fills up... which results in a hang. I will discuss with the team what to do about the make_reverb_dataset implementation. In the meantime you should be good setting num_parallel_calls to 1.
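To spell out why this hangs with the settings in this thread: with batch_size 256 and 12 fetch workers each assembling its own batch, the samples the rate limiter allows can be spread across workers so that no single worker ever accumulates 256 elements. A sketch of the suggested workaround, assuming make_reverb_dataset forwards num_parallel_calls to its tf.data pipeline (the parameter named above):

# Hypothetical variant of the builder's make_dataset_iterator with a single fetch
# worker, so every sampled element goes into the same batch.
def make_dataset_iterator(self, replay_client):
  dataset = datasets.make_reverb_dataset(
      table=self._config.replay_table_name,
      server_address=replay_client.server_address,
      batch_size=self._config.batch_size,
      prefetch_size=self._config.prefetch_size,
      num_parallel_calls=1,  # the default of 12 can leave every per-worker batch unfilled
  )
  return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])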
@qstanczyk Yes! I tried setting it to smaller values and that seems to help.
I also found that tf.data.AUTOTUNE would break things. This is true even when I completely remove the training (e.g., comment out the update step) in my learner.
I was just going to post here that I found num_parallel_calls seems to play a role. I had a suspicion that if num_parallel_calls does not divide the batch size things may not work, but it seems that even setting num_parallel_calls to something like 8 still results in a deadlock.
Looking forward to the fixes. I will test if num_parallel_calls unblocks my issue. Thanks again for the help, you may have saved me a few more days of trying to figure out what's wrong.
OK, so now that the deadlock behavior is gone, I still get the issue mentioned in https://github.com/deepmind/acme/issues/235.
Attaching the latest stacktrace:
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/environment_loop.py", line 176, in run
result = self.run_episode()
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/environment_loop.py", line 115, in run_episode
self._actor.update()
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/jax/layouts/local_layout.py", line 140, in update
super().update()
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/agents/agent.py", line 105, in update
self._batch_size_upper_bounds = [
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/agents/agent.py", line 106, in <listcomp>
math.ceil(t.info.rate_limiter_info.sample_stats.completed /
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/reverb/server.py", line 229, in info
return reverb_types.TableInfo.from_serialized_proto(proto_string)
File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/reverb/reverb_types.py", line 80, in from_serialized_proto
proto = schema_pb2.TableInfo.FromString(proto_string)
BufferError: INVALID_ARGUMENT: Python buffer protocol is only defined for CPU buffers.
[reverb/cc/platform/default/server.cc:84] Shutting down replay server
E0531 00:47:26.028071 139694191777600 base.py:130] Timeout (10000 ms) exceeded when flushing the writer before deleting it. Caught Reverb exception: Flush call did not complete within provided timeout of 0:00:10
My protobuf version is 3.20.1
Hi,
I also experienced deadlocks while training a custom pixel-based agent, but this issue has been occurring both in the Local and the Distributed Layout, and before the recent changes in the implementation. One curious thing I noticed is that the random deadlock only happens when using multiple GPUs, not with a single GPU, and it seems to be related to pmap itself.
I mentioned this problem here: https://github.com/google/jax/discussions/10763. There seem to be other people encountering the issue in other JAX codebases as well. One of JAX's maintainers managed to reproduce this deadlock of the pmapped function with someone else's code.
Maybe the JAX pmap bug on GPUs is actually the underlying problem here too?
This change makes the non-distributed setup equivalent to the distributed setup from the "deadlock" perspective. As long as rate limiters are configured properly, both setups should work fine.
@qstanczyk that looks good! Maybe this change should also propagate to LocalLayout?
LocalLayout should go away soon; it is being replaced with run_experiment.
@qstanczyk There are sometimes good reasons for creating custom training loops, and it would be great if some way of creating a single-process agent from a builder remains available for users who implement their own. Initially I thought LocalLayout was the way to go, but alternative ways are also OK.
It is still possible to build a custom training loop for a single-process agent by cloning run_experiment. We are trying to move away from the Layout design to make the agents' code easier for new users to understand.
Right, that makes sense to me. Still, I think it can be a lot of duplication if forking run_experiment requires keeping a copy of something like _TrainingAdder in the forked copy. For me, I probably want to customize the training loop, but not necessarily something like _TrainingAdder.
By the way, is the future plan to write tests for agents through run_experiment? I saw the old agent test files removed, but I still think there is a lot of value in having those instead of just having the examples. For my code, I use the tests extensively to guard against API changes in Acme and also during my own refactoring.
In that case it should be fine to clone the loop itself while referencing the original _TrainingAdder. We wanted to keep the entire logic of run_experiment (including _TrainingAdder) in a single file for the convenience of the reader. The other reason is that the implementation of run_experiment might change, so _TrainingAdder could be modified or removed. But if you work with a fixed version of Acme or are fine with merging changes, it is OK to reference _TrainingAdder.
For the tests where the logic of the test was equivalent to the logic of the examples, we tried to deduplicate the code. But in some cases I believe it makes sense to have a standalone test too.
@qstanczyk Thanks for the clarification. Yeah, the old tests are sometimes quite duplicated indeed. I was hoping some testing utilities exist that can help with testing an entire agent right from the builder.
I don't work with a fixed version of Acme; instead I work against HEAD. There are a lot of very useful things that get added, and I keep my code updated to follow them. I guess for now it's indeed a good idea to just fork run_experiment. I do that anyway, since I need to set up other additional things. I will merge in changes if the mechanism for preventing blocking in the single-process case changes.
Hi,
I recently started migrating my JAX agents to use the new LocalLayout, which incorporates the changes that simplify the setup for ensuring that non-distributed agents do not block. I have noticed that I have started to experience deadlocks with my old parameters.
I have a pixel-based agent similar to D4PG. For Reverb, I use a batch size of 256, samples_per_insert of 128.0, and a samples-per-insert tolerance rate of 0.1 (following D4PG, with the exception of SPI which is now 128). I use num_sgd_steps_per_step = 1. I have noticed that in this case I can experience a deadlock. The way I set up the data iterator and prefetching follows exactly what's on the master branch.
My understanding is that with the new API, things should not block on inserts since the rate limiter has been adjusted to ensure that this cannot happen. For some reason, however, I found that things may block on sampling. With some debugging, it looks like the deadlock happens here:
https://github.com/deepmind/acme/blob/ad073a85319435246a5fac21978e1be655191778/acme/agents/agent.py#L106
I have tried to debug this problem without any luck. @qstanczyk I noticed that you made these changes recently; do you have any idea what may be causing this issue? AFAIK the difference compared to the previous version is that the local agent now uses the table's rate-limiting behavior to control the actor/learner stepping frequency, instead of doing it manually in the Agent, but I suspect that because of prefetching something does not work out right.
Any tips for debugging this would be greatly appreciated! Happy to include more details on the setup!
Many thanks in advance!