DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Issue: HER with SAC algorithm #1892

Closed wadeKeith closed 2 months ago

wadeKeith commented 3 months ago

🐛 Bug

I have no idea why my code won't run. Is it a problem with the vectorized env, or something else? Please help me check this code.

Code example

```python
import numpy as np
from env import UR5Env
import math
from stable_baselines3.common.env_checker import check_env
import gymnasium as gym

from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
import time
import os

from utilize import linear_schedule

# Environment configuration
seed = 429
reset_arm_poses = [math.pi, -math.pi / 2, -math.pi * 5 / 9, -math.pi * 4 / 9, math.pi / 2, 0]
reset_gripper_range = [0, 0.085]
visual_sensor_params = {
    'image_size': [128, 128],
    'dist': 1.0,
    'yaw': 90.0,
    'pitch': -25.0,
    'pos': [0.6, 0.0, 0.0525],
    'fov': 75.0,
    'near_val': 0.1,
    'far_val': 5.0,
    'show_vision': False,
}
robot_params = {
    "reset_arm_poses": reset_arm_poses,
    "reset_gripper_range": reset_gripper_range,
}
sim_params = {"use_gui": False, 'timestep': 1 / 240, 'control_type': 'joint', 'gripper_enable': False}
env_kwargs_dict = {"sim_params": sim_params, "robot_params": robot_params, "visual_sensor_params": visual_sensor_params}

# Training
vec_env = make_vec_env(UR5Env, n_envs=1, env_kwargs=env_kwargs_dict, seed=seed)
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=False)
model = SAC(
    "MultiInputPolicy",
    vec_env,
    learning_rate=linear_schedule(1e-6),
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=1, goal_selection_strategy="future"),
    buffer_size=1000000,
    learning_starts=100,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=(1, "step"),  # (2, "episode"), (5, "step")
    tensorboard_log='./logs',
    seed=seed,
    verbose=1,
    device='cuda',
)
model.learn(total_timesteps=500000, log_interval=10, tb_log_name="ur5_robotiq140_sac", progress_bar=True)
model.save("./model/ur5_robotiq140_sac")
stats_path = os.path.join('./normalize_file/', "vec_normalize_sac.pkl")
vec_env.save(stats_path)

vec_env.close()
del model, vec_env  # remove to demonstrate saving and loading

# Evaluation: reload the model and the VecNormalize statistics
sim_params['use_gui'] = True
env_kwargs_dict = {"sim_params": sim_params, "robot_params": robot_params, "visual_sensor_params": visual_sensor_params}
vec_env = make_vec_env(UR5Env, n_envs=1, env_kwargs=env_kwargs_dict, seed=seed)
vec_env = VecNormalize.load(stats_path, vec_env)
vec_env.training = False
vec_env.norm_reward = False

model = SAC.load("./model/ur5_robotiq140_sac", env=vec_env)
obs = vec_env.reset()
dones = False
while not dones:
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
vec_env.close()
exit()
```

Relevant log output / Error message

```
  File "/home/zxr/Documents/yincheng/Github/UR5/train_sac.py", line 76, in <module>
    model.learn(total_timesteps=500000,
RuntimeError: Unable to sample before the end of the first episode. We recommend choosing a value for learning_starts that is greater than the maximum number of timesteps in the environment.
```

System Info

OS: Linux-5.15.0-101-generic-x86_64-with-glibc2.31 # 111~20.04.1-Ubuntu SMP Mon Mar 11 15:44:43 UTC 2024

araffin commented 3 months ago

Hello, the error message is quite explicit, no? You should increase the warmup phase (learning_starts) and have a look at the recommended parameters in the RL Zoo.
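
For reference, a minimal sketch of what the suggested warmup fix could look like. The MAX_EPISODE_STEPS value is a hypothetical placeholder (the actual UR5Env episode length is not stated in the issue), and vec_env refers to the wrapped environment built in the code above; only learning_starts changes with respect to the original call.

```python
from stable_baselines3 import SAC, HerReplayBuffer

# Hypothetical upper bound on the UR5Env episode length; replace with the real value.
MAX_EPISODE_STEPS = 1000

# With HerReplayBuffer, nothing can be sampled until at least one episode has
# finished, so the warmup phase should be longer than the longest possible episode.
model = SAC(
    "MultiInputPolicy",
    vec_env,  # the VecNormalize-wrapped UR5Env from the code above
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=1, goal_selection_strategy="future"),
    learning_starts=MAX_EPISODE_STEPS + 1,  # warmup longer than one full episode
    buffer_size=1_000_000,
    batch_size=256,
    verbose=1,
)
model.learn(total_timesteps=500_000)
```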

wadeKeith commented 2 months ago

> Hello, the error message is quite explicit, no? You should increase the warmup phase (learning_starts) and have a look at the recommended parameters in the RL Zoo.

Thanks, the problem is solved. train_freq needed to be modified: it triggered training too frequently, before the first episode was complete, so the replay buffer could not sample.
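
A sketch of the train_freq change described above, under the same assumptions as the previous snippet (vec_env and the other hyperparameters are reused from the issue code). Training once per finished episode means that, whenever a gradient update is triggered, the HER buffer already contains at least one complete episode to sample from.

```python
from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    vec_env,  # the VecNormalize-wrapped UR5Env from the code above
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=1, goal_selection_strategy="future"),
    learning_starts=100,
    train_freq=(1, "episode"),  # train after each finished episode instead of every step
    gradient_steps=-1,          # optional: as many gradient steps as env steps collected
    verbose=1,
)
```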