Closed: N00bcak closed this issue 1 month ago.
`modules` is recursive, unlike `children`.
I think you're right about point 2, though: if we cache the vmap call as we currently do, we can't get this to work.
I'll push a fix shortly.
> `modules` is recursive, unlike `children`. I think you're right about point 2, though: if we cache the vmap call as we currently do, we can't get this to work. I'll push a fix shortly.
Hmm, that's surprising. Were it truly recursive, I'd expect `_vmap_randomness == "different"` at the end of it, because the double break should prevent the `else` clause from triggering.
I've got something running at the moment, so I can't provide proof of this just yet, but point 1 is what I observed when stepping through the code.
I fixed a couple more things, but I can't try your example because I'm (as always) having problems with PettingZoo dependencies. Maybe you can check that it works, or perhaps give an example that does not involve an extra library?
> I fixed a couple more things, but I can't try your example because I'm (as always) having problems with PettingZoo dependencies. Maybe you can check that it works, or perhaps give an example that does not involve an extra library?
FWIW, I was able to replicate this issue on the "navigation" scenario in VMAS (I figured you could run it since it's featured with MAPPO in a tutorial :P):
For a little extra information, I patched `LossModule.vmap_randomness`:
@property
def vmap_randomness(self):
    """Debug-patched copy of ``LossModule.vmap_randomness`` posted in this thread.

    On first access (while ``self._vmap_randomness`` is None) it scans every
    value in ``self.__dict__`` that is a ``torch.nn.Module``, walking it with
    ``.modules()``. If any visited module is an instance of
    ``RANDOM_MODULE_LIST``, the mode becomes ``"different"``; otherwise it
    becomes ``"error"``. The result is cached in ``self._vmap_randomness`` and
    returned.

    Debug addition: the type of every visited module is collected in
    ``modules`` and printed as one comma-separated line.
    """
    modules = []  # debug: type names of every module visited
    if self._vmap_randomness is None:
        do_break = False
        # NOTE(review): only attributes held directly in __dict__ are scanned;
        # nn.Module normally keeps registered submodules in self._modules, so
        # presumably only networks stored outside that registry are seen here.
        for val in self.__dict__.values():
            if isinstance(val, torch.nn.Module):
                # .modules() walks registered submodules only. Per the
                # maintainer's diagnosis, the MARL net's inner module is not
                # registered the usual way, so its Dropout is never reached.
                for module in val.modules():
                    modules.append(str(type(module)))
                    if isinstance(module, RANDOM_MODULE_LIST):
                        self._vmap_randomness = "different"
                        do_break = True
                        break
            if do_break:
                # double break
                break
        else:
            # for-else: runs only if the outer loop finished without the double break.
            self._vmap_randomness = "error"
        # NOTE(review): indentation was lost in the paste; the print is assumed
        # to sit inside the `if` (printed once, on the first access) — confirm.
        print(','.join(modules))
    return self._vmap_randomness
This is the script proper:
# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'.
# Kept ahead of the remaining imports on purpose; the ParallelEnv built under
# __main__ also requests mp_start_method="spawn".
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)
from copy import deepcopy
import tqdm
import numpy as np
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.envs import check_env_specs, VmasEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import SACLoss, ValueEstimators
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
# Numerical epsilon. CustomTanhTransform._inverse clamps its input to
# [-1 + EPS, 1 - EPS], and TanhNormalStable.log_prob uses it to derive the
# lower clamp on the returned log-probability.
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
    """Multi-agent MLP (400 -> 300 hidden units) with dropout before the output layer.

    Each network produced by ``_build_single_net`` is
    ``Linear(in, 400) -> act -> Linear(400, 300) -> act -> Dropout(0.5) ->
    Linear(300, n_agent_outputs)``. The dropout is the random op that raises
    the vmap "randomness error mode" RuntimeError reported in this thread.

    Args:
        n_agent_inputs: input features per agent. ``None`` makes the first
            layer an ``nn.LazyLinear`` whose input size is inferred on the
            first forward pass.
        n_agent_outputs: output features per agent.
        n_agents: number of agents; inputs must satisfy ``shape[-2] == n_agents``.
        centralised: if True, the agent and feature dims of the input are
            flattened together before the network is applied.
        share_params: forwarded to ``MultiAgentNetBase``.
        device: device for the built networks; used whenever
            ``_build_single_net`` is called with ``device=None``.
        activation_class: activation module class placed after each hidden layer.
        **kwargs: forwarded to ``MultiAgentNetBase``.
    """

    def __init__(
        self,
        n_agent_inputs: int | None,
        n_agent_outputs: int,
        n_agents: int,
        centralised: bool,
        share_params: bool,
        device="cpu",
        activation_class=nn.Tanh,
        **kwargs,
    ):
        # Set before super().__init__(): the base class presumably builds the
        # per-agent nets during its own __init__, and _build_single_net reads
        # these attributes.
        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device
        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device=device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        """Validate the agent dim of ``inputs`` and flatten it when centralised.

        Raises:
            ValueError: if ``inputs.shape[-2] != self.n_agents``.
        """
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        """Initialise every ``nn.Linear`` inside ``net`` in place and return ``net``.

        Weights get Xavier-uniform init with gain 1. Any layer whose
        ``out_features == self.n_agent_outputs`` (intended for the output
        layer) gets gain 0.01. Biases are zeroed. Lazy layers whose weights
        are still uninitialised are skipped, because their shape is unknown
        until the first forward pass.
        """

        def init_layer_params(layer):
            if not isinstance(layer, nn.Linear):
                return
            # nn.LazyLinear subclasses nn.Linear but holds an
            # UninitializedParameter until materialised; xavier init needs a
            # concrete shape, so leave it to the layer's own reset on first use.
            if isinstance(layer.weight, nn.parameter.UninitializedParameter):
                return
            weight_gain = 1.0 / (100 if layer.out_features == self.n_agent_outputs else 1)
            nn.init.xavier_uniform_(layer.weight, gain=weight_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

        net.apply(init_layer_params)
        return net

    def _build_single_net(self, *, device, **kwargs):
        """Build one agent network on ``device`` (``self.device`` if ``device`` is None)."""
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            # A centralised net sees every agent's features concatenated.
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        # nn.Linear cannot take in_features=None; defer it with a lazy layer.
        if n_agent_inputs is None:
            first_layer = nn.LazyLinear(400)
        else:
            first_layer = nn.Linear(n_agent_inputs, 400)
        # Honour the device the caller asked for instead of silently ignoring it.
        target_device = self.device if device is None else device
        model = nn.Sequential(
            first_layer,
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Dropout(0.5),  # <- The dropout is here!
            nn.Linear(300, self.n_agent_outputs),
        ).to(target_device)
        return self.init_net_params(model)
class CustomTanhTransform(D.transforms.TanhTransform):
    """``TanhTransform`` whose inverse clamps its input strictly inside (-1, 1).

    Only ``_inverse`` and ``log_abs_det_jacobian`` are overridden; everything
    else is inherited from ``torch.distributions.transforms.TanhTransform``.
    """

    def _inverse(self, y):
        """Return atanh(y) = 0.5 * log((1 + y) / (1 - y)).

        ``y`` is first clamped to ``[-1 + EPS, 1 - EPS]`` so that neither
        log1p term is evaluated at -1. Formula yoinked from SB3, which took it
        from Pyro (https://github.com/pyro-ppl/pyro).
        """
        lower, upper = -1.0 + EPS, 1.0 - EPS
        clipped = y.clamp(lower, upper)
        return 0.5 * (clipped.log1p() - clipped.neg().log1p())

    def log_abs_det_jacobian(self, x, y):
        """Return log|d tanh(x) / dx| = log(1 - tanh(x)^2), computed from ``x`` only.

        Same stable derivation as PyTorch's ``TanhTransform``::

            log(1 - tanh^2(x)) = log(sech^2(x))
                               = 2 * log(2 / (e^x + e^-x))
                               = 2 * (log 2 - log(e^x * (1 + e^-2x)))
                               = 2 * (log 2 - x - log(1 + e^-2x))
                               = 2 * (log 2 - x - softplus(-2x))

        ``y`` is accepted for the ``Transform`` interface but not used.
        """
        log_two = math.log(2.0)
        softplus_term = nn.functional.softplus(-2.0 * x)
        return 2.0 * (log_two - x - softplus_term)
class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.

    The base distribution is ``D.Independent(D.Normal(loc, scale), event_dims)``
    pushed through a single ``CustomTanhTransform``, whose inverse clamps its
    input to ``[-1 + EPS, 1 - EPS]``. ``log_prob`` additionally floors its
    result at ``log(EPS)``.

    Args:
        loc: location (mean) of the underlying Normal.
        scale: scale (standard deviation) of the underlying Normal.
        event_dims: number of rightmost dims treated as one event
            (forwarded to ``D.Independent``).
    '''

    def __init__(self, loc, scale, event_dims=1):
        self._event_dims = event_dims
        self._t = [CustomTanhTransform()]
        # update() builds the base distribution and runs
        # TransformedDistribution.__init__ on this first call.
        self.update(loc, scale)

    def log_prob(self, value):
        """Score ``value`` by inverting the transform(s) and adding the base log-prob.

        Mirrors ``TransformedDistribution.log_prob`` (subtracting each
        transform's log|det J|), then clamps the result from below at
        ``log(EPS)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        # log_prob is a natural-log density, so the floor must use the natural
        # log as well: math.log10(EPS) gave -7 instead of log(1e-7) ~= -16.1.
        return torch.clamp(log_prob, min=math.log(EPS))

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        """Point the distribution at new ``loc``/``scale`` tensors.

        If a base distribution already exists and the shapes are unchanged,
        its Normal is updated in place; otherwise the base distribution is
        rebuilt and ``TransformedDistribution.__init__`` is re-run.
        """
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        """Return the base Normal's mean pushed through the transform(s)."""
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m
# Main Function
if __name__ == "__main__":
    import contextlib

    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 256
    MAX_EPISODE_STEPS = 1000
    # Fall back to the CPU so the reproduction also runs on machines without a GPU.
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1
    WARMUP_STEPS = 0
    TRAIN_TIMESTEPS = int(1e7)
    SEED = 42

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn():
        """Return a zero-argument factory for the transformed VMAS "navigation" env."""

        def base_env_fn():
            return VmasEnv(
                scenario="navigation",
                num_envs=NUM_EXPLORE_WORKERS,
                continuous_actions=True,
                max_steps=200,
                device="cpu",
                seed=None,
                # Scenario kwargs
                n_agents=NUM_AGENTS,
            )

        def env_with_transforms():
            init_env = base_env_fn()
            return TransformedEnv(
                init_env,
                Compose(
                    StepCounter(max_steps=MAX_EPISODE_STEPS),
                    # NOTE(review): every key is repeated NUM_AGENTS times with
                    # identical copies; a single entry is likely sufficient since
                    # the reward key is already grouped under "agents" — confirm.
                    RewardSum(
                        in_keys=[init_env.reward_key for _ in range(NUM_AGENTS)],
                        out_keys=[("agents", "episode_reward")] * NUM_AGENTS,
                        reset_keys=["_reset"] * NUM_AGENTS,
                    ),
                ),
            )

        return env_with_transforms

    # Everything registered on `cleanup` is released on exit, including when an
    # exception (such as the vmap RuntimeError being reproduced) escapes.
    with contextlib.ExitStack() as cleanup:
        train_env = env_fn()()
        cleanup.callback(train_env.close)
        if train_env.is_closed:
            train_env.start()
        check_env_specs(train_env)
        print(train_env.done_spec)

        obs_dim = train_env.full_observation_spec["agents", "observation"].shape[-1]
        action_dim = train_env.full_action_spec["agents", "action"].shape[-1]

        policy_net = nn.Sequential(
            SMACCNet(
                n_agent_inputs=obs_dim,
                n_agent_outputs=2 * action_dim,
                n_agents=NUM_AGENTS,
                centralised=False,
                share_params=True,
                device=DEVICE,
                activation_class=nn.LeakyReLU,
            ),
            # Splits the 2 * action_dim outputs into the (loc, scale) pair.
            NormalParamExtractor(),
        ).to(DEVICE)
        critic_net = SMACCNet(
            n_agent_inputs=obs_dim + action_dim,
            n_agent_outputs=1,
            n_agents=NUM_AGENTS,
            centralised=True,
            share_params=True,
            device=DEVICE,
            activation_class=nn.LeakyReLU,
        ).to(DEVICE)

        policy_net_td_module = TensorDictModule(
            module=policy_net,
            in_keys=[("agents", "observation")],
            out_keys=[("agents", "loc"), ("agents", "scale")],
        )
        # Concatenates observation and action along the last dim as critic input.
        obs_act_module = TensorDictModule(
            lambda obs, act: torch.cat([obs, act], dim=-1),
            in_keys=[("agents", "observation"), ("agents", "action")],
            out_keys=[("agents", "obs_act")],
        )
        critic_net_td_module = TensorDictModule(
            module=critic_net,
            in_keys=[("agents", "obs_act")],
            out_keys=[("agents", "state_action_value")],
        )

        # Attach our raw policy network to a probabilistic actor
        policy_actor = ProbabilisticActor(
            module=policy_net_td_module,
            spec=train_env.full_action_spec["agents", "action"],
            in_keys=[("agents", "loc"), ("agents", "scale")],
            out_keys=[("agents", "action")],
            distribution_class=TanhNormalStable,
            return_log_prob=True,
        )
        critic_actor = TensorDictSequential(obs_act_module, critic_net_td_module)

        collector = SyncDataCollector(
            ParallelEnv(
                NUM_EXPLORE_WORKERS,
                [env_fn() for _ in range(NUM_EXPLORE_WORKERS)],
                device=None,
                mp_start_method="spawn",
            ),
            policy=policy_actor,
            frames_per_batch=BATCH_SIZE,
            max_frames_per_traj=-1,
            total_frames=TRAIN_TIMESTEPS,
            device="cpu",
            policy_device="cpu",
            reset_at_each_iter=False,
        )
        # Registered after train_env.close, so it runs first on exit
        # (same order as the original shutdown() / close() calls).
        cleanup.callback(collector.shutdown)

        replay_buffer = TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.9,
            storage=LazyMemmapStorage(
                1e5,
                device="cpu",
                scratch_dir="googoogaagaa/",
            ),
            priority_key="td_error",
            batch_size=BATCH_SIZE,
        )

        # Dummy loss module
        sac_loss = SACLoss(
            actor_network=policy_actor,
            qvalue_network=critic_actor,
            num_qvalue_nets=NUM_CRITICS,
            loss_function="l2",
            delay_actor=False,
            delay_qvalue=True,
            alpha_init=0.1,
        )
        sac_loss.set_keys(
            action=("agents", "action"),
            state_action_value=("agents", "state_action_value"),
            reward=("agents", "reward"),
            done=("agents", "done"),
            terminated=("agents", "terminated"),
        )
        sac_loss.make_value_estimator(
            value_type=ValueEstimators.TD0,
            gamma=VALUE_GAMMA,
        )

        # Compiling replay_buffer.sample works :D
        @torch.compile(mode="reduce-overhead")
        def rb_sample():
            td_sample = replay_buffer.sample()
            if td_sample.device != torch.device(DEVICE):
                td_sample = td_sample.to(DEVICE, non_blocking=False)
            else:
                td_sample = td_sample.clone()
            return td_sample

        def test_compile():
            td_sample = rb_sample()
            return sac_loss(td_sample)

        # Progress is counted in frames. Each collector batch holds
        # frames_per_batch=BATCH_SIZE frames, so counting batches against
        # total=TRAIN_TIMESTEPS (as before) could never reach 100%.
        pbar = cleanup.enter_context(tqdm.tqdm(total=TRAIN_TIMESTEPS, unit="frames"))
        for tensordict in collector:
            # Expand the env-level done/terminated flags to the per-agent
            # reward shape, matching the ("agents", ...) keys set on the loss.
            agent_flag_shape = tensordict.get_item_shape(("next", "agents", "reward"))
            tensordict.set(
                ("next", "agents", "done"),
                tensordict.get(("next", "done")).unsqueeze(-1).expand(agent_flag_shape),
            )
            tensordict.set(
                ("next", "agents", "terminated"),
                tensordict.get(("next", "terminated")).unsqueeze(-1).expand(agent_flag_shape),
            )
            tensordict = tensordict.reshape(-1)
            pbar.update(tensordict.numel())
            replay_buffer.extend(tensordict.to("cpu", non_blocking=True))
            pbar.write("Hey Hey!!! :D")
            a = test_compile()
            print(a)
Running the script now yields:
<class 'torchrl.modules.tensordict_module.actors.ProbabilisticActor'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'torch.nn.modules.container.Sequential'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>,<class 'tensordict.nn.distributions.continuous.NormalParamExtractor'>,<class 'torchrl.modules.tensordict_module.probabilistic.SafeProbabilisticModule'>,<class 'tensordict.nn.sequence.TensorDictSequential'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'tensordict.nn.common.TensorDictModule'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>
<...omitted for brevity...>
...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/dropout.py", line 59, in forward
return F.dropout(input, self.p, self.training, self.inplace)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/functional.py", line 1295, in dropout
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
It seems that the list of modules being checked does not go deep enough, because `DropoutNd` is nowhere to be seen :P
Got it: here the problem is that the dropout is hidden by the MARL model, which does not register its inner module in the usual way. It should be fairly easy to fix.
Describe the bug
`SACLoss` has flawed checks for determining the nature of `vmap_randomness`. Therefore, stochastic modules cannot be used in its constituent networks.
To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks for both code and stack traces.
Expected behavior
The `SACLoss` module performs forward passes successfully.
System info
Describe the characteristic of your environment:
Reason and Possible fixes
There are essentially two reasons for this error:
1. `RANDOM_MODULE_LIST` …
2. `LossModule.set_vmap_randomness` … as `self.vmap_randomness` is accessed immediately during initialization.
Checklist