thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

how to convert Batch into ndarray/tensor #1064

Closed qmpzzpmq closed 7 months ago

qmpzzpmq commented 8 months ago

Hi, when I run my toy script I hit a bug and also have a question about it:

when a batch is passed into the Policy class's forward function it is of type Batch, but it is then passed on to the model, and the model's obs input is actually a tensor/ndarray. I cannot find the conversion mechanism.

My toy script is:

import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
from tianshou.utils.net.common import Net

from gymnasium.envs.registration import register

task = "DummyAudioAmplify-v0"
register(
    id=task,
    entry_point="uim_sfit.envs.dummy:DummpyEnv",
    max_episode_steps=20,
)

lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!

# you can also try with SubprocVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
env = gym.make(task, render_mode="human")
state_shape = [x.shape for x in env.observation_space.values()]
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)

policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=gamma, 
    action_space=env.action_space,
    estimation_step=n_step,
    target_update_freq=target_freq
)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)  # because DQN uses epsilon-greedy method

result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
print(f"Finished training in {result.timing.total_time} seconds")

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)

and the env is defined as:

import numpy as np
import soundfile as sf
import librosa

import torch
from gymnasium import Env
from gymnasium import spaces

AUDIO_PATH = "/data/data/aishell3/test/SSB0005.wav"

# https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
class DummpyEnv(Env):
    metadata = {"render_modes": ["human"]}
    def __init__(self, render_mode=None, size=3):
        super().__init__()
        self.size = size
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )
        # We have 4 actions, corresponding to "right", "up", "left", "down"
        self.action_space = spaces.Discrete(4)

        """
        The following dictionary maps abstract actions from `self.action_space` to
        the direction we will walk in if that action is taken.
        I.e. 0 corresponds to "right", 1 to "up" etc.
        """
        self._action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }
        self.audio_spec = librosa.stft(sf.read(AUDIO_PATH)[0])

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

    def _get_obs(self):
        return {
            "agent": torch.as_tensor(self._agent_location), 
            "target": torch.as_tensor(self._target_location),
        }

    def _get_info(self):
        return {
            "distance": np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }

    def reset(self):
        self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )
        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def step(self, action):
        # Map the action (element of {0,1,2,3}) to the direction we walk in
        direction = self._action_to_direction[action]
        # We use `np.clip` to make sure we don't leave the grid
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )
        # An episode is done iff the agent has reached the target
        terminated = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if terminated else 0  # Binary sparse rewards
        observation = self._get_obs()
        info = self._get_info()

        # observation, reward, terminated, truncated, info
        return observation, reward, terminated, False, info

    def render(self):
        # TODO: using self._agent_location to amplify the audio
        audio_spec = self.audio_spec * np.repeat(
            np.expand_dims(self._location_to_amplify(), axis=1),
            self.audio_spec.shape[1],
            axis=1
        )
        return librosa.istft(audio_spec)

    def _location_to_amplify(self):
        size = self.audio_spec.shape[0]
        amplify = self._agent_location * 0.2 + 0.8
        size1 = size // 2
        size2 = size - size1
        amplify = np.concatenate(
            [
                np.array([amplify[0]] * size1), 
                np.array([amplify[1]] * size2)
            ], axis=0
        )
        return amplify

    def close(self):
        pass

the bug is:

/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/librosa/core/intervals.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import resource_filename
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:168: DeprecationWarning: WARN: Current gymnasium version requires that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.
  logger.deprecation(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:181: DeprecationWarning: WARN: Current gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.
  logger.deprecation(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:131: UserWarning: WARN: The obs returned by the `reset()` method was expecting a numpy array, actual type: <class 'torch.Tensor'>
  logger.warn(
/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/gymnasium/spaces/box.py:240: UserWarning: WARN: Casting input x to numpy array.
  gym.logger.warn("Casting input x to numpy array.")
Traceback (most recent call last):
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/haoyu.tang/.vscode-server/extensions/ms-python.python-2024.0.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/haoyu.tang/uim_sfit/test_pipeline.py", line 64, in <module>
    ).run()
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 441, in run
    deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 252, in __iter__
    self.reset()
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/base.py", line 237, in reset
    test_result = test_episode(
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/trainer/utils.py", line 27, in test_episode
    result = collector.collect(n_episode=n_episode)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/data/collector.py", line 279, in collect
    result = self.policy(self.data, last_state)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/policy/modelfree/dqn.py", line 160, in forward
    logits, hidden = model(obs_next, state=state, info=batch.info)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/utils/net/common.py", line 248, in forward
    logits = self.model(obs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/utils/net/common.py", line 142, in forward
    obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
  File "/home/haoyu.tang/.conda/envs/sfit/lib/python3.10/site-packages/tianshou/data/batch.py", line 689, in __len__
    raise TypeError(f"Object {obj} in {self} has no len()")
TypeError: Object 0 in Batch(
    target: tensor(0),
    agent: tensor(0),
) has no len()
MischaPanch commented 8 months ago

Hi. You seem to be using an older version of tianshou; could you please try again after installing from master?

I just tried examples/discrete/discrete_dqn.py, and it runs through without problems.

Apart from switching to the latest version, a likely reason for this failure is that your observation space is a dict (in fact, I think that's very likely the cause). Tianshou should in principle support this, but there has been a lot of recent work on improving tianshou's internals, and we haven't focused on dict spaces.
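
To illustrate why a Dict observation trips this up, here is a minimal standalone snippet (my own illustration, not tianshou internals): the network ultimately calls torch.as_tensor on the observation it receives, which is fine for a stacked ndarray but fails for a mapping:

import numpy as np
import torch

# A flat Box-style observation arrives as a stacked ndarray and converts cleanly.
flat_obs = np.zeros((4, 2), dtype=np.float32)
print(torch.as_tensor(flat_obs).shape)  # torch.Size([4, 2])

# A Dict-style observation arrives as a mapping, which torch.as_tensor cannot handle.
dict_obs = {"agent": np.zeros(2), "target": np.zeros(2)}
try:
    torch.as_tensor(dict_obs)
except Exception as exc:
    print(type(exc).__name__, exc)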

We will include and test support for dict observations at some point in the future, but it's not our priority. It's usually easy to adjust your environment to have a Discrete/Box observation space - either in the env itself or through an env wrapper.
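
Such a wrapper could look roughly like this (a rough sketch with a made-up name, DictToBoxWrapper, assuming the dict entries are fixed-size numeric arrays; gymnasium's built-in FlattenObservation wrapper does essentially the same thing):

import numpy as np
import gymnasium as gym
from gymnasium import spaces

class DictToBoxWrapper(gym.ObservationWrapper):
    """Concatenate the entries of a Dict observation into one flat Box."""

    def __init__(self, env):
        super().__init__(env)
        sizes = [int(np.prod(space.shape)) for space in env.observation_space.spaces.values()]
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(sum(sizes),), dtype=np.float32
        )

    def observation(self, observation):
        # flatten and concatenate the entries in the order of the Dict space keys
        return np.concatenate(
            [np.asarray(observation[key], dtype=np.float32).ravel()
             for key in self.env.observation_space.spaces]
        )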

Some general recommendations for getting something running on your env:

  1. If you don't plan on adjusting the implementation of the algorithmic details, consider using the high-level interfaces. They have a more declarative syntax and should be much easier for you to use (see e.g. examples/discrete/discrete_dqn_hl.py).
  2. I recommend never using the gym.register and gym.make mechanism for custom envs. You can just instantiate your env directly and write a factory for it (or use a lambda).
  3. If you don't use gym.make, it is trivial to write an env wrapper that turns your Dict space into a Discrete/Box space and use that wrapper in your EnvFactory (see the sketch below this list).
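
For points 2 and 3, it could be used like this in your toy script (again just an illustration: DictToBoxWrapper is the hypothetical wrapper sketched above, and DummpyEnv is your env class):

from tianshou.env import DummyVectorEnv
from uim_sfit.envs.dummy import DummpyEnv

train_num, test_num = 10, 100

def make_env():
    # instantiate the env directly instead of going through gym.register / gym.make
    return DictToBoxWrapper(DummpyEnv())

train_envs = DummyVectorEnv([make_env for _ in range(train_num)])
test_envs = DummyVectorEnv([make_env for _ in range(test_num)])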

@opcode81 @Trinkle23897 FYI. I'm making a new issue for Dict space support

MischaPanch commented 8 months ago

If you want to have a look into it @qmpzzpmq, I set up #1065 outlining the problem

qmpzzpmq commented 8 months ago

@MischaPanch thanks for your reply. My tianshou version is 0.5.1, which I guess is the newest version of tianshou on pip.

MischaPanch commented 8 months ago

Could you install the version on master instead?
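
For example (one common way, installing directly from the GitHub repository):

pip install git+https://github.com/thu-ml/tianshou.git@master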

MischaPanch commented 7 months ago

Closing as stale. FYI: a new version of tianshou has been released; you can install it with pip.