HandsomeAIccx closed this issue 3 years ago.
Could you please provide the scripts to reproduce this error? It seems you use a customized env.
import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.continuous import Actor, Critic
# plus the import of the custom package that registers the 'carla-v0' env


class Net(nn.Module):
    def __init__(self, space, action_shape, device):
        super().__init__()
        shape1 = space['birdeye'].shape
        self.model1 = nn.Sequential(*[
            nn.Linear(np.prod(shape1), 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ]).to('cuda')
        shape2 = space['state'].shape
        self.model2 = nn.Sequential(*[
            nn.Linear(np.prod(shape2), 4),
            nn.ReLU(inplace=True),
            nn.Linear(4, 4), nn.ReLU(inplace=True),
            nn.Linear(4, 4), nn.ReLU(inplace=True),
            nn.Linear(4, np.prod(action_shape))
        ]).to('cuda')
        self.device = device
        self.output_dim = 2

    def forward(self, s, state=None, info={}):
        # birdeye = s['birdeye']
        # state = s['state']
        birdeye = s.birdeye
        state = s.state
        if not isinstance(birdeye, torch.Tensor):
            birdeye = torch.tensor(birdeye, dtype=torch.float, device='cuda')
        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, dtype=torch.float, device='cuda')
        birdeye_batch = birdeye.shape[0]
        state_batch = birdeye.shape[0]
        birdeye_logits = self.model1(birdeye.view(birdeye_batch, -1))
        state_logits = self.model2(state.view(state_batch, -1))
        logits = birdeye_logits + state_logits
        return logits, state
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='carla-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=10000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--exploration-noise', type=float, default=0.1)
    parser.add_argument("--start-timesteps", type=int, default=25000)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--step-per-collect', type=int, default=1)
    parser.add_argument('--update-per-step', type=int, default=1)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=1)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    return parser.parse_args()
def test_ddpg(args=get_args()):
    params = {
        'number_of_vehicles': 100,
        'number_of_walkers': 0,
        'display_size': 256,  # screen size of bird-eye render
        'max_past_step': 1,  # the number of past steps to draw
        'dt': 0.1,  # time interval between two frames
        'discrete': False,  # whether to use discrete control space
        'discrete_acc': [-3.0, 0.0, 3.0],  # discrete value of accelerations
        'discrete_steer': [-0.2, 0.0, 0.2],  # discrete value of steering angles
        'continuous_accel_range': [-3.0, 3.0],  # continuous acceleration range
        'continuous_steer_range': [-0.3, 0.3],  # continuous steering angle range
        'ego_vehicle_filter': 'vehicle.lincoln*',  # filter for defining ego vehicle
        'port': 2000,  # connection port
        'town': 'Town03',  # which town to simulate
        'task_mode': 'random',  # mode of the task, [random, roundabout (only for Town03)]
        'max_time_episode': 1000,  # maximum timesteps per episode
        'max_waypt': 12,  # maximum number of waypoints
        'obs_range': 32,  # observation range (meter)
        'lidar_bin': 0.125,  # bin size of lidar sensor (meter)
        'd_behind': 12,  # distance behind the ego vehicle (meter)
        'out_lane_thres': 2.0,  # threshold for out of lane
        'desired_speed': 8,  # desired speed (m/s)
        'max_ego_spawn_times': 200,  # maximum times to spawn ego vehicle
        'display_route': True,  # whether to render the desired route
        'pixor_size': 64,  # size of the pixor labels
        'pixor': False,  # whether to output PIXOR observation
    }
    env = gym.make(args.task, params=params)
    # args.state_shape = env.observation_space.shape
    args.state_shape = env.observation_space.spaces['birdeye']
    args.state_shape = args.state_shape.shape
    args.action_shape = env.action_space.shape
    args.max_action = env.action_space.high[0]
    args.exploration_noise = args.exploration_noise * args.max_action
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    if args.test_num > 1:
        test_envs = DummyVectorEnv(
            [lambda: gym.make(args.task, params=params) for _ in range(args.test_num)]
        )
    else:
        test_envs = gym.make(args.task, params=params)
    if args.training_num > 1:
        train_envs = DummyVectorEnv(
            [lambda: gym.make(args.task, params=params) for _ in range(args.training_num)]
        )
    else:
        train_envs = gym.make(args.task, params=params)
    # test_envs = gym.make(args.task)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(env.observation_space, action_shape=args.action_shape, device=args.device)
    print(args.action_shape, args.max_action, args.device)
    actor = Actor(
        net_a, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c = Net(
        env.observation_space,
        args.action_shape,
        device=args.device
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    policy = DDPGPolicy(
        actor,
        actor_optim,
        critic,
        critic_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        estimation_step=args.n_step,
        action_space=env.action_space
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ddpg'
    log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = offpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.step_per_collect,
            args.test_num,
            args.batch_size,
            save_fn=save_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_ddpg()
My env is based on carla. I have to set both training-num and test-num to 1 because carla doesn't allow creating more than one env on a single machine.
Here's the intermediate result in the forward function:
Object -0.025024621697931952 in Batch(A:array,B:array,C:float,D:array) has no len()
So what are A, B, C and D? I cannot find them in your screenshot.
It seems that there's a float state which is not a numpy array. Maybe that's the root cause.
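If so, a quick way to check is to make the env always return numpy arrays in the observation dict. A rough sketch (the helper name sanitize_obs is made up here, it is not part of your script):

import numpy as np

def sanitize_obs(obs):
    # Coerce every value of the dict observation to (at least) a 1-D numpy array.
    # A bare Python float for 'state' is exactly what makes Batch fail with "has no len()".
    return {k: np.atleast_1d(np.asarray(v)) for k, v in obs.items()}

You could call this on every observation returned by reset() and step() and see whether the error disappears.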
Sorry, ABCD is for birdeye / camera / state / lidar. The detailed output is
TypeError: Object 0.10905823607911809 in Batch(
camera: array([[135, 147, 169],
[135, 147, 170]
.......................
[136, 149, 172]], dtype=uint8),
birdeye: array([[185, 188, 181],
[ 82, 87, 87],
.....................
[ 0, 0, 0]], dtype=uint8),
state: 0.10905823607911809,
lidar: array([[ 0, 0, 0],
.................
[ 0, 0, 0]], dtype=uint8),
) has no len()
Seems state is problematic. However, the state in the screenshot is an np.array with shape [1, 4], so that's fine?
What is the state in your network's forward function? Is it used for a recurrent mechanism?
def forward(self, s, state=None, info={}):
    # birdeye = s['birdeye']
    # state = s['state']
    birdeye = s.birdeye
    state = s.state
    if not isinstance(birdeye, torch.Tensor):
        birdeye = torch.tensor(birdeye, dtype=torch.float, device='cuda')
    if not isinstance(state, torch.Tensor):
        state = torch.tensor(state, dtype=torch.float, device='cuda')
    birdeye_batch = birdeye.shape[0]
    state_batch = birdeye.shape[0]
    birdeye_logits = self.model1(birdeye.view(birdeye_batch, -1))
    state_logits = self.model2(state.view(state_batch, -1))
    logits = birdeye_logits + state_logits
    return logits, state
Here, the returned state will be the next input state. If you don't want to use the recurrent mechanism, you'd better use return logits, None; otherwise the state is confusing in this code snippet.
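For reference, a minimal non-recurrent version of that forward (same layers as in your snippet, only the return value changes):

def forward(self, s, state=None, info={}):
    birdeye = torch.as_tensor(s.birdeye, dtype=torch.float, device=self.device)
    obs_state = torch.as_tensor(s.state, dtype=torch.float, device=self.device)
    batch = birdeye.shape[0]
    logits = self.model1(birdeye.view(batch, -1)) + self.model2(obs_state.view(batch, -1))
    return logits, None  # no hidden state to carry over when no RNN is used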
Indeed, but I think this is not the root cause.
I find the s in the forward function is:
The batch in __len__ is:
The dimensions of birdeye and state change a lot, and state becomes float64. May I ask what the potential reason could be? Is it because of the network?
Try to use a mock env to debug.
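For instance, a throwaway env that mimics the nested observation layout lets you run the whole training loop without a CARLA server. This is only a sketch: the class name MockCarlaEnv, the image shapes, and the action dimension are assumptions based on the output above, not taken from your env.

import gym
import numpy as np
from gym import spaces

class MockCarlaEnv(gym.Env):
    """Stand-in for 'carla-v0' with the same nested observation structure."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict({
            'camera': spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            'birdeye': spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            'lidar': spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            'state': spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),
        })
        self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
        self._step = 0

    def reset(self):
        self._step = 0
        return self.observation_space.sample()

    def step(self, action):
        self._step += 1
        obs = self.observation_space.sample()
        done = self._step >= 100  # short fake episodes
        return obs, 0.0, done, {}

    def seed(self, seed=None):
        self.observation_space.seed(seed)
        return [seed]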
The reason is that DDPG's critic network needs obs and act as input (to calculate Q(s, a)). However, your environment's observation is a dict (i.e., nested), and you only define one network for nested_obs -> logits; the network (nested_obs, act) -> Q(s, a) is missing. (It is not defined how to concatenate these two things.)
To resolve this issue, you need to define your own critic network that receives the nested obs and the action and outputs Q(s, a). As a first step, try inheriting from the existing Critic network class.
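A rough sketch of such a critic, written from scratch here rather than by subclassing tianshou's Critic; the name NestedCritic and the hidden sizes are assumptions:

import numpy as np
import torch
import torch.nn as nn

class NestedCritic(nn.Module):
    """Q(s, a) head for a dict observation: flatten birdeye + state, concat the action."""

    def __init__(self, observation_space, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        birdeye_dim = int(np.prod(observation_space['birdeye'].shape))
        state_dim = int(np.prod(observation_space['state'].shape))
        act_dim = int(np.prod(action_shape))
        self.q = nn.Sequential(
            nn.Linear(birdeye_dim + state_dim + act_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, obs, act, info={}):
        birdeye = torch.as_tensor(obs.birdeye, dtype=torch.float32, device=self.device)
        state = torch.as_tensor(obs.state, dtype=torch.float32, device=self.device)
        act = torch.as_tensor(act, dtype=torch.float32, device=self.device)
        batch = birdeye.shape[0]
        x = torch.cat(
            [birdeye.view(batch, -1), state.view(batch, -1), act.view(batch, -1)], dim=1
        )
        return self.q(x)

An instance of it (moved to the right device) could then be passed to DDPGPolicy in place of Critic(net_c, ...).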
May I ask what's going wrong with this error message?