Attempting to invoke `torch.compile` on any of the above-mentioned classes results in similar errors (see below).
To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks for both code and stack traces.
# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)
from copy import deepcopy
import tqdm
import numpy as np
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import SACLoss, ValueEstimators
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
    """Per-agent MLP (two hidden layers, 400 and 300 units) built on torchrl's
    ``MultiAgentNetBase``.

    When ``centralised`` is True, each agent's network consumes the
    concatenation of every agent's inputs; otherwise each agent only sees its
    own input slice.
    """

    def __init__(self,
                 n_agent_inputs: int | None,
                 n_agent_outputs: int,
                 n_agents: int,
                 centralised: bool,
                 share_params: bool,
                 device = 'cpu',
                 activation_class = nn.Tanh,
                 **kwargs):
        # These attributes must be set *before* super().__init__ runs:
        # MultiAgentNetBase.__init__ calls _build_single_net, which reads them.
        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device
        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device=device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        """Validate the agent dimension; flatten it away for centralised nets."""
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        return inputs.flatten(-2, -1) if self.centralised else inputs

    def init_net_params(self, net):
        """Xavier-initialise every Linear layer; the output layer gets gain 1/100."""
        def _init(layer):
            if not isinstance(layer, nn.Linear):
                return
            is_output_layer = layer.out_features == self.n_agent_outputs
            torch.nn.init.xavier_uniform_(layer.weight, gain = 1. / (100 if is_output_layer else 1))
            if 'bias' in layer.state_dict():
                torch.nn.init.zeros_(layer.bias)
        net.apply(_init)
        return net

    def _build_single_net(self, *, device, **kwargs):
        """Build one 400-300 MLP; centralised nets take all agents' inputs."""
        in_features = self.n_agent_inputs
        if self.centralised and in_features is not None:
            in_features = self.n_agent_inputs * self.n_agents
        layers = [
            nn.Linear(in_features, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs),
        ]
        # NOTE: moved to self.device (not the `device` keyword), matching the
        # original implementation.
        return self.init_net_params(nn.Sequential(*layers).to(self.device))
class CustomTanhTransform(D.transforms.TanhTransform):
    """``TanhTransform`` variant with a clipped, numerically stable inverse.

    The inverse clamps its argument away from +/-1 before applying atanh
    (clipping trick adapted from SB3/Pyro), and the log-abs-det-jacobian uses
    the softplus formulation from PyTorch's own TanhTransform.
    """

    def _inverse(self, y):
        """Stable atanh: clamp y into the open interval (-1, 1), then
        atanh(y) = 0.5 * (log1p(y) - log1p(-y))."""
        clipped = y.clamp(-1. + 1e-7, 1. - 1e-7)
        return 0.5 * (clipped.log1p() - (-clipped).log1p())

    def log_abs_det_jacobian(self, x, y):
        """log|d tanh(x)/dx| = log(1 - tanh^2(x)) = 2*(log 2 - x - softplus(-2x)).

        Derivation (from PyTorch's TanhTransform):
            log(1 - tanh^2(x)) = log(sech^2(x))
                               = 2*log(2 / (e^x + e^(-x)))
                               = 2*(log 2 - x - log(1 + e^(-2x)))
                               = 2*(log 2 - x - softplus(-2x))
        """
        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))
class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.'''

    def __init__(self, loc, scale, event_dims = 1):
        # super().__init__ is deliberately deferred to update(), which builds
        # the base Independent(Normal) distribution from (loc, scale).
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        # Walk the transform stack backwards, accumulating the
        # change-of-variables correction (-log|det J|) for each transform.
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        # Floor the log-probability to avoid -inf.
        # NOTE(review): the floor is math.log10(EPS) = -7, not math.log(EPS)
        # (~ -16.1) — confirm which floor was intended.
        log_prob = torch.clamp(log_prob, min = math.log10(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        # Fast path: when shapes are unchanged, mutate the existing base
        # Normal in place. Otherwise rebuild it and (re)run
        # TransformedDistribution.__init__.
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        # Deterministic action: push the base-distribution mean through the
        # transform stack (tanh), as is conventional for TanhNormal.
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m
# Main Function
if __name__ == "__main__":
    # ---- run configuration ------------------------------------------------
    NUM_AGENTS = 3
    NUM_CRITICS = 2                    # NOTE(review): unused below (SACLoss hardcodes num_qvalue_nets=2)
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 256            # NOTE(review): unused below
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda:0"
    REPLAY_BUFFER_SIZE = int(1e6)      # NOTE(review): unused — LazyMemmapStorage below hardcodes 1e5
    VALUE_GAMMA = 0.99                 # NOTE(review): unused — make_value_estimator below hardcodes 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 3e-4                          # NOTE(review): unused (no optimizer in this repro)
    UPDATE_STEPS_PER_EXPLORATION = 1
    WARMUP_STEPS = 0
    TRAIN_TIMESTEPS = int(1e7)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True):
        """Return a zero-argument factory that builds a transformed multiwalker env.

        mode: render_mode forwarded to PettingZooEnv.
        parallel: NOTE(review): ignored — base_env_fn hardcodes parallel=True.
        rew_scale: if True, use the custom reward scales; otherwise PZ defaults.
        """
        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            # Raw PettingZoo multiwalker env wrapped for torchrl.
            return PettingZooEnv(task = "multiwalker_v9",
                                 parallel = True,
                                 seed = 42,
                                 n_walkers = NUM_AGENTS,
                                 terminate_reward = terminate_scale,
                                 forward_reward = forward_scale,
                                 fall_reward = fall_scale,
                                 shared_reward = False,
                                 max_cycles = MAX_EPISODE_STEPS,
                                 render_mode = mode,
                                 device = DEVICE
                                 )
        env = base_env_fn # noqa: E731

        def env_with_transforms():
            # Wrap the base env with step counting and per-agent reward sums.
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                StepCounter(max_steps = MAX_EPISODE_STEPS),
                RewardSum(
                    in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
                    out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
                    reset_keys = ["_reset"] * NUM_AGENTS
                ),
            )
            )
            return init_env
        return env_with_transforms

    # ---- env + model construction -----------------------------------------
    train_env = env_fn(None, parallel = False)()
    if train_env.is_closed:
        train_env.start()
    check_env_specs(train_env)
    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    # Decentralised policy emitting (loc, scale) via NormalParamExtractor.
    policy_net = nn.Sequential(
        SMACCNet(n_agent_inputs = obs_dim,
                 n_agent_outputs = 2 * action_dim,
                 n_agents = NUM_AGENTS,
                 centralised = False,
                 share_params = True,
                 device = DEVICE,
                 activation_class = nn.LeakyReLU,
                 ),
        NormalParamExtractor(),
    )
    # Centralised critic scoring concatenated (obs, action) pairs.
    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True,
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.LeakyReLU,
                          )
    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                            )
    # Concatenate observation and action for the critic's input.
    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                      in_keys = [("walker", "observation"), ("walker", "action")],
                                      out_keys = [("walker", "obs_act")]
                                      )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                            )
    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormalStable,
        return_log_prob = True,
    )
    # Warm up the actor once on a fake tensordict (materialises lazy specs).
    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)
    critic_actor = TensorDictSequential(
        obs_act_module, critic_net_td_module
    )

    # Can't compile these either...
    policy_actor = torch.compile(policy_actor)
    critic_actor = torch.compile(critic_actor)

    # Smoke-test the compiled actor/critic on a real reset + step.
    with torch.no_grad():
        reset_obs = train_env.reset()
        reset_obs_clean = deepcopy(reset_obs)
        action = policy_actor(reset_obs)
        state_action_value = critic_actor(action)
        reset_obs = train_env.reset()
        reset_obs["walker", "action"] = torch.zeros((*reset_obs["walker", "observation"].shape[:-1], action_dim))
        train_env.rand_action(reset_obs)
        action = train_env.step(reset_obs)

    # ---- collection + training machinery ----------------------------------
    collector = SyncDataCollector(
        ParallelEnv(NUM_EXPLORE_WORKERS,
                    [
                        env_fn(None, parallel = False)
                        for _ in range(NUM_EXPLORE_WORKERS)
                    ],
                    device = None,
                    mp_start_method = "spawn"
                    ),
        policy = policy_actor,
        frames_per_batch = BATCH_SIZE,
        max_frames_per_traj = -1,
        total_frames = TRAIN_TIMESTEPS,
        device = 'cpu',
        reset_at_each_iter = False
    )
    # Dummy loss module
    replay_buffer = TensorDictPrioritizedReplayBuffer(
        alpha = 0.7,
        beta = 0.9,
        storage = LazyMemmapStorage(
            1e5,
            device = 'cpu',
            scratch_dir = "temp/"
        ),
        priority_key = "td_error",
        batch_size = BATCH_SIZE,
    )
    sac_loss = SACLoss(actor_network = policy_actor,
                       qvalue_network = critic_actor,
                       num_qvalue_nets = 2,
                       loss_function = "l2",
                       delay_actor = False,
                       delay_qvalue = True,
                       alpha_init = 0.1,
                       )
    # Point the loss at the nested ("walker", ...) keys of this multi-agent env.
    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )
    sac_loss.make_value_estimator(
        value_type = ValueEstimators.TD0,
        gamma = 0.99,
    )

    # Compiling replay_buffer.sample works :D
    @torch.compile(mode = "reduce-overhead")
    def rb_sample():
        # Sample a batch and move it to DEVICE (clone when already there so the
        # caller always owns the memory).
        td_sample = replay_buffer.sample()
        if td_sample.device != torch.device(DEVICE):
            td_sample = td_sample.to(
                DEVICE,
                non_blocking = False
            )
        else:
            td_sample = td_sample.clone()
        return td_sample

    # This does not :P
    @torch.compile(disable = True)
    def test_compile():
        # Even with compile disabled on this wrapper, calling into the
        # (compiled) sac_loss reproduces the reported error.
        td_sample = rb_sample()
        return sac_loss(td_sample)

    # ---- main loop: collect, store, then trigger the failing loss call ----
    samples = 0
    for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
        tensordict = tensordict.reshape(-1)
        samples += tensordict.numel()
        replay_buffer.extend(tensordict.to('cpu', non_blocking = True))
        pbar.write("Hey Hey!!! :D")
        a = test_compile()
        print(a)
    collector.shutdown()
    train_env.close()
File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 382, in <module>
a = test_compile()
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 373, in test_compile
td_sample = rb_sample()
File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 374, in torch_dynamo_resume_in_test_compile_at_373
return sac_loss(td_sample)
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1582, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 125, in decorate_context
with ctx_factory():
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 126, in torch_dynamo_resume_in_decorate_context_at_125
return func(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 289, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 559, in forward
loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 723, in _qvalue_v2_loss
target_value = self._compute_target_v2(tensordict)
File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 687, in _compute_target_v2
tensordict = tensordict.clone(False)
File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 692, in torch_dynamo_resume_in__compute_target_v2_at_687
), self.actor_network_params.to_module(self.actor_network):
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/utils.py", line 1189, in new_func
out = func(_self, *args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/base.py", line 949, in to_module
return self._to_module(
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/params.py", line 174, in new_func
out = getattr(self._param_td, name)(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 444, in _to_module
if value.is_empty():
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
if not item.is_empty():
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
if not item.is_empty():
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
if not item.is_empty():
[Previous line repeated 1 more time]
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/base.py", line 3350, in is_empty
for _ in self.keys(True, True):
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
result = inner_convert(
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 295, in _convert_frame_assert
cache_size = compute_cache_size(frame, cache_entry)
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/cache_size.py", line 142, in compute_cache_size
if _has_same_id_matched_objs(frame, cache_entry):
File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/cache_size.py", line 123, in _has_same_id_matched_objs
if weakref_from_frame != weakref_from_cache_entry:
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/params.py", line 174, in new_func
out = getattr(self._param_td, name)(*args, **kwargs)
File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 505, in __ne__
raise KeyError(
KeyError: "keys in TensorDict(<omitted for brevity>) mismatch, got {'2', '4', '1', '0', '3'} and {'module'}"
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
Expected behavior
Since non-compile version executes successfully, compile is expected to succeed.
System info
Describe the characteristic of your environment:
Describe how the library was installed (pip, source, ...)
Describe the bug
Attempting to invoke `torch.compile` on any of the above-mentioned classes results in similar errors (see below).

To Reproduce
Steps to reproduce the behavior.
Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
Please use the markdown code blocks for both code and stack traces.
Expected behavior
Since non-compile version executes successfully, compile is expected to succeed.
System info
Describe the characteristic of your environment:
Reason and Possible fixes
Perhaps it is due to the decorators you mentioned in discord?
Checklist