thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

Using FQF to solve CartPole, the total_reward of each episode stays around 8-11; how should I modify the code? #538

Closed: jaried closed this issue 2 years ago

jaried commented 2 years ago

Using FQF to solve CartPole, the total_reward of each episode stays around 8-11; how should I modify the code?

My code is as follows:

import argparse
from tianshou.data import Batch
import gym
import numpy as np
import torch
import tianshou as ts
from tianshou.policy import FQFPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-fractions', type=int, default=32)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--ent-coef', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args

def main(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    feature_net = Net(
        args.state_shape,
        args.hidden_sizes[-1],
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device,
        softmax=False
    )
    net = FullQuantileFunction(
        feature_net,
        args.action_shape,
        args.hidden_sizes,
        num_cosines=args.num_cosines,
        device=args.device
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = torch.optim.RMSprop(
        fraction_net.parameters(), lr=args.fraction_lr
    )
    agent = FQFPolicy(
        net,
        optim,
        fraction_net,
        fraction_optim,
        args.gamma,
        args.num_fractions,
        args.ent_coef,
        args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)

    buffer = ts.data.ReplayBuffer(size=1_000_000)
    steps = 0
    total_rewards = []
    for episode in range(1, 1_000_000 + 1):
        if episode >= 100:
            mean_rewards = np.mean(total_rewards[-100:])
            if mean_rewards >= env.spec.reward_threshold:
                print(f'Episode {episode}, average reward: {mean_rewards}, goal reached!')
                break
        state = env.reset()
        total_reward = 0
        while True:
            # env.render()
            batch = Batch(obs=[state], info=None)
            action = agent(batch).act[0]
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            buffer.add(Batch(obs=[state], act=action, obs_next=[next_state], rew=reward, done=done, info=None))

            state = next_state
            steps += 1

            if steps % 10 == 0:
                batch, indice = buffer.sample(batch_size=args.batch_size)
                b_ret = agent.process_fn(batch, buffer, indice)
                agent.learn(b_ret)
            if done:
                print(f'Episode:{episode:3,d},Total_reward:{int(total_reward):5,d}')
                total_rewards.append(total_reward)
                break

if __name__ == '__main__':
    main(get_args())
Trinkle23897 commented 2 years ago

You need to add epsilon-greedy in the training loop.
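
Epsilon-greedy just means: with probability eps take a random action, otherwise keep the greedy one. A minimal sketch of the idea, independent of tianshou's built-in helpers:

import numpy as np

def epsilon_greedy(greedy_action: int, n_actions: int, eps: float) -> int:
    # With probability eps return a uniformly random action, otherwise the greedy one.
    if np.random.rand() < eps:
        return np.random.randint(n_actions)
    return greedy_action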

jaried commented 2 years ago

If I don't want to explore, should eps be set to 1.0 or 0.0?

jaried commented 2 years ago

I added epsilon-greedy to the training loop, but it has no effect.

import math
import argparse
from tianshou.data import Batch
import gym
import numpy as np
import torch
import tianshou as ts
from tianshou.policy import FQFPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    # parser.add_argument('--task', type=str, default='MountainCar-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-fractions', type=int, default=32)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--ent-coef', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args

def main(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    feature_net = Net(
        args.state_shape,
        args.hidden_sizes[-1],
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device,
        softmax=False
    )
    net = FullQuantileFunction(
        feature_net,
        args.action_shape,
        args.hidden_sizes,
        num_cosines=args.num_cosines,
        device=args.device
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = torch.optim.RMSprop(
        fraction_net.parameters(), lr=args.fraction_lr
    )
    agent = FQFPolicy(
        net,
        optim,
        fraction_net,
        fraction_optim,
        args.gamma,
        args.num_fractions,
        args.ent_coef,
        args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    try:
        agent.load_state_dict(torch.load('FQF-CartPole.pth'))
    except Exception as e:
        print(e)
    buffer = ts.data.ReplayBuffer(size=1_000_000)

    steps = 0
    total_rewards = []
    batch_size = args.batch_size
    epsilon = 0.99
    agent.set_eps(epsilon)
    for episode in range(1, 1_000_000 + 1):
        buffer2 = ts.data.ReplayBuffer(size=200)
        if episode >= 100:
            mean_rewards = np.mean(total_rewards[-100:])
            if mean_rewards >= env.spec.reward_threshold:
                print(f'Episode:{episode:,d}, average reward: {mean_rewards}, goal reached!')
                break
        state = env.reset()
        total_reward = 0
        while True:
            # env.render()

            batch = Batch(obs=[state], info=None)
            action = agent(batch).act[0]
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            if done and total_reward < env.spec.reward_threshold:
                reward = -1000
            buffer.add(Batch(obs=[state], act=action, obs_next=[next_state], rew=reward, done=done, info=None))
            buffer2.add(Batch(obs=[state], act=action, obs_next=[next_state], rew=reward, done=done, info=None))

            state = next_state
            steps += 1
            epsilon = max(epsilon - 0.0001, 0.01)
            agent.set_eps(epsilon)
            if steps % 10 == 0:
                batch, indice = buffer.sample(batch_size=batch_size)
                b_ret = agent.process_fn(batch, buffer, indice)
                loss = agent.learn(b_ret)
                # print(f"loss:{loss['loss']:,.4e}")
            if done:
                # len_buffer2 = len(buffer2)
                # batches = math.ceil(len_buffer2 * 1. / batch_size)
                # for i in range(batches):
                #     try:
                #         start = i * batch_size
                #         to = (i + 1) * batch_size
                #         batch = buffer2[start:to]
                #         indice = np.arange(start, start + len(batch))
                #         b_ret = agent.process_fn(batch, buffer2, indice)
                #         agent.learn(b_ret)
                #     except Exception as e:
                #         print(e)
                total_rewards.append(total_reward)
                average_reward = np.mean(total_rewards[-100:])
                try:
                    print(f'Episode: {episode:3,d}, Total_reward:{int(total_reward):5,d}, '
                          f'average_reward: {average_reward:7,.2f}, '
                          f'epsilon:{epsilon:.3f}, '
                          f"loss:{loss['loss']:,.4e}")
                except Exception as e:
                    print(e)
                try:
                    torch.save(agent.state_dict(), 'FQF-CartPole.pth')
                except Exception as e:
                    print(e)
                break

if __name__ == '__main__':
    main(get_args())
Trinkle23897 commented 2 years ago

https://github.com/thu-ml/tianshou/blob/c248b4f87e46d8fca229f29d5cabb15211c842e9/tianshou/data/collector.py#L244-L245

jaried commented 2 years ago

https://github.com/thu-ml/tianshou/blob/c248b4f87e46d8fca229f29d5cabb15211c842e9/tianshou/data/collector.py#L244-L245

But I don't use a Collector?

Trinkle23897 commented 2 years ago

I know. My point is that you forgot to add this line to inject some noise into your action; set_eps is just self.eps = eps and has no effect on its own. https://github.com/thu-ml/tianshou/blob/c248b4f87e46d8fca229f29d5cabb15211c842e9/tianshou/policy/modelfree/dqn.py#L64-L66

https://github.com/thu-ml/tianshou/blob/c248b4f87e46d8fca229f29d5cabb15211c842e9/tianshou/policy/modelfree/dqn.py#L178-L188

https://github.com/thu-ml/tianshou/blob/c248b4f87e46d8fca229f29d5cabb15211c842e9/tianshou/policy/base.py#L88-L99
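
In your manual loop that means passing the greedy action, as a numpy array with its batch dimension, through exploration_noise before stepping the environment. A rough sketch, reusing the variables from your script:

from tianshou.data import Batch, to_numpy

# inside the per-step loop
batch = Batch(obs=[state], info=None)
act = to_numpy(agent(batch).act)           # greedy action(s), shape (1,)
act = agent.exploration_noise(act, batch)  # with probability eps, swap in a random action
next_state, reward, done, _ = env.step(int(act[0]))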

jaried commented 2 years ago

I modified the code according to your suggestion, but the average reward starts to drop once it reaches only about 180. After more than 2000 episodes it still hasn't solved the task. Where is the problem?

import math
import argparse
from tianshou.data import Batch
import gym
import numpy as np
import torch
import tianshou as ts
from tianshou.policy import FQFPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    # parser.add_argument('--task', type=str, default='MountainCar-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-fractions', type=int, default=32)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--ent-coef', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    args = parser.parse_known_args()[0]
    return args

def main(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    feature_net = Net(
        args.state_shape,
        args.hidden_sizes[-1],
        hidden_sizes=args.hidden_sizes[:-1],
        device=args.device,
        softmax=False
    )
    net = FullQuantileFunction(
        feature_net,
        args.action_shape,
        args.hidden_sizes,
        num_cosines=args.num_cosines,
        device=args.device
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
    fraction_optim = torch.optim.RMSprop(
        fraction_net.parameters(), lr=args.fraction_lr
    )
    agent = FQFPolicy(
        net,
        optim,
        fraction_net,
        fraction_optim,
        args.gamma,
        args.num_fractions,
        args.ent_coef,
        args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    try:
        agent.load_state_dict(torch.load('FQF-CartPole.pth'))
    except Exception as e:
        print(e)
    buffer = ts.data.ReplayBuffer(size=1_000_000)

    steps = 0
    total_rewards = []
    batch_size = args.batch_size
    epsilon = 0.99
    agent.set_eps(epsilon)
    for episode in range(1, 1_000_000 + 1):
        buffer2 = ts.data.ReplayBuffer(size=200)
        if episode >= 100:
            mean_rewards = np.mean(total_rewards[-100:])
            if mean_rewards >= env.spec.reward_threshold:
                print(f'Episode:{episode:,d}, average reward: {mean_rewards}, goal reached!')
                break
        state = env.reset()
        total_reward = 0
        while True:
            # env.render()

            batch = Batch(obs=[state], info=None)
            action = agent(batch).act[0]
            action = agent.exploration_noise(action[np.newaxis], batch)[0]
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            # if done and total_reward < env.spec.reward_threshold:
            #     reward = -1000
            buffer.add(Batch(obs=[state], act=action, obs_next=[next_state], rew=reward, done=done, info=None))
            buffer2.add(Batch(obs=[state], act=action, obs_next=[next_state], rew=reward, done=done, info=None))

            state = next_state
            steps += 1
            epsilon = max(epsilon - 0.0001, 0.1)
            agent.set_eps(epsilon)
            if steps % 10 == 0:
                batch, indice = buffer.sample(batch_size=batch_size)
                b_ret = agent.process_fn(batch, buffer, indice)
                loss = agent.learn(b_ret)
                # print(f"loss:{loss['loss']:,.4e}")
            if done:
                # len_buffer2 = len(buffer2)
                # batches = math.ceil(len_buffer2 * 1. / batch_size)
                # for i in range(batches):
                #     try:
                #         start = i * batch_size
                #         to = (i + 1) * batch_size
                #         batch = buffer2[start:to]
                #         indice = np.arange(start, start + len(batch))
                #         b_ret = agent.process_fn(batch, buffer2, indice)
                #         agent.learn(b_ret)
                #     except Exception as e:
                #         print(e)
                total_rewards.append(total_reward)
                average_reward = np.mean(total_rewards[-100:])
                try:
                    print(f'Episode: {episode:3,d}, Total_reward:{int(total_reward):5,d}, '
                          f'average_reward: {average_reward:7,.2f}, '
                          f'epsilon:{epsilon:.3f}, '
                          f"loss:{loss['loss']:,.4e}")
                except Exception as e:
                    print(e)
                try:
                    torch.save(agent.state_dict(), 'FQF-CartPole.pth')
                except Exception as e:
                    print(e)
                break

if __name__ == '__main__':
    main(get_args())
Trinkle23897 commented 2 years ago

Could you please use a vectorized environment to speed up your training? This feature is implemented in the Collector, and the hyperparameters are tuned for this vec_env setting. If you are not sure what's wrong with your current code, there is always a reference here: https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_fqf.py
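
A sketch of that vectorized setup, following the pattern of the reference script and reusing the names from your script (agent, args); the argument names come from the Collector/offpolicy_trainer API of that release and may differ in newer versions:

train_envs = ts.env.DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
buf = ts.data.VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
train_collector = ts.data.Collector(agent, train_envs, buf, exploration_noise=True)
test_collector = ts.data.Collector(agent, test_envs, exploration_noise=True)
result = ts.trainer.offpolicy_trainer(
    agent, train_collector, test_collector,
    max_epoch=args.epoch, step_per_epoch=args.step_per_epoch,
    step_per_collect=args.step_per_collect, episode_per_test=args.test_num,
    batch_size=args.batch_size, update_per_step=args.update_per_step,
    train_fn=lambda epoch, env_step: agent.set_eps(args.eps_train),
    test_fn=lambda epoch, env_step: agent.set_eps(args.eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
)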

If you insist on using a single environment, the right approach is probably to play with those hyperparameters until you find a good combination.

UPDATE: pay attention to the epsilon-greedy hyperparameter; it really matters for performance. Change epsilon = max(epsilon - 0.0001, 0.1) to epsilon = max(epsilon - 0.0001, 0.01).
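
In the single-environment loop that change is just a lower floor for the decay; the schedule itself is a hyperparameter to tune, for example:

epsilon = 1.0
eps_final = 0.01    # lower floor so the policy becomes (almost) greedy once it has learned
eps_decay = 1e-4

# inside the per-step loop
epsilon = max(epsilon - eps_decay, eps_final)
agent.set_eps(epsilon)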