thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

TypeError: type 'Collector' is not subscriptable #1218

Closed: 948702669 closed this 1 month ago

948702669 commented 2 months ago
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

import tianshou as ts
from tianshou.data import CollectStats
from tianshou.utils.space_info import SpaceInfo

def main() -> None:
    task = "CartPole-v1"
    lr, epoch, batch_size = 1e-3, 10, 64
    train_num, test_num = 10, 100
    gamma, n_step, target_freq = 0.9, 3, 320
    buffer_size = 20000
    eps_train, eps_test = 0.1, 0.05
    step_per_epoch, step_per_collect = 10000, 10

    logger = ts.utils.TensorboardLogger(SummaryWriter("dqn"))  # TensorBoard is supported!
    # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html

    # You can also try SubprocVectorEnv, which will use parallelization
    train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
    test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

    from tianshou.utils.net.common import Net

    # Note: You can easily define other networks.
    # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
    env = gym.make(task, render_mode="human")
    assert isinstance(env.action_space, gym.spaces.Discrete)
    space_info = SpaceInfo.from_env(env)
    state_shape = space_info.observation_info.obs_shape
    action_shape = space_info.action_info.action_shape
    net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
    optim = torch.optim.Adam(net.parameters(), lr=lr)

    policy: ts.policy.DQNPolicy = ts.policy.DQNPolicy(
        model=net,
        optim=optim,
        discount_factor=gamma,
        action_space=env.action_space,
        estimation_step=n_step,
        target_update_freq=target_freq,
    )
    train_collector = ts.data.Collector[CollectStats](
        policy,
        train_envs,
        ts.data.VectorReplayBuffer(buffer_size, train_num),
        exploration_noise=True,
    )
    test_collector = ts.data.Collector[CollectStats](
        policy,
        test_envs,
        exploration_noise=True,
    )  # because DQN uses epsilon-greedy method

    def stop_fn(mean_rewards: float) -> bool:
        if env.spec:
            if not env.spec.reward_threshold:
                return False
            else:
                return mean_rewards >= env.spec.reward_threshold
        return False

    result = ts.trainer.OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=epoch,
        step_per_epoch=step_per_epoch,
        step_per_collect=step_per_collect,
        episode_per_test=test_num,
        batch_size=batch_size,
        update_per_step=1 / step_per_collect,
        train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
        test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
        stop_fn=stop_fn,
        logger=logger,
    ).run()
    print(f"Finished training in {result.timing.total_time} seconds")

    # watch performance
    policy.set_eps(eps_test)
    collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True)
    collector.collect(n_episode=100, render=1 / 35)

if __name__ == "__main__":
    main()

When running the above code, I get "TypeError: type 'Collector' is not subscriptable". How can I solve this?

opcode81 commented 2 months ago

Please provide the used versions as described in the issue template.
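For reference, a minimal sketch that prints this kind of information (assuming the standard __version__ attributes; see the issue template for the exact fields it asks for):

import sys

import gymnasium
import numpy
import torch

import tianshou

# Tianshou, Gymnasium, PyTorch and NumPy versions, plus the Python build and OS
print(
    tianshou.__version__,
    gymnasium.__version__,
    torch.__version__,
    numpy.__version__,
    sys.version,
    sys.platform,
)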

spfrommer commented 2 months ago

I also got the same error message. Versions:

tianshou 1.1.0, gymnasium 0.29.1, torch 2.4.1+cu121, numpy 1.26.4, Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0], linux

Running examples/mujoco/mujoco_ppo.py.

Anne-kj commented 1 month ago

When running test/discrete/test_sac.py, I got the same error, using tianshou 1.1.0.

opcode81 commented 1 month ago

@948702669, @Anne-kj, @spfrommer, you cannot expect source code from the current development branch (master) to run with any version of Tianshou that isn't the latest development version.

If you are using v1.1.0 (or any other older version), then you must also take the example code from that version, i.e. check out the corresponding tag (v1.1.0) and use the example code from there.

Please confirm that taking the correct example code resolves your issue.
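For illustration only: the Collector[...] subscript syntax exists only on the development branch, so with a v1.1.0 install the collectors from the example above would be created without it. A minimal sketch, assuming the rest of the script is unchanged and the v1.1.0 Collector signature:

    # v1.1.0-style construction: plain Collector, no generic subscript
    train_collector = ts.data.Collector(
        policy,
        train_envs,
        ts.data.VectorReplayBuffer(buffer_size, train_num),
        exploration_noise=True,
    )
    # The test collector keeps exploration noise because DQN is epsilon-greedy.
    test_collector = ts.data.Collector(
        policy,
        test_envs,
        exploration_noise=True,
    )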