DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: How to avoid saving an external LLM model while saving a customized DQN policy #2025

Closed · chrisgao99 closed this issue 1 month ago

chrisgao99 commented 1 month ago

🐛 Bug

Hello,

I wrote a customized DQN policy that uses a Large Language Model (LLM) to modify the Q-values before the policy predicts an action. The idea is simple: every time the agent gets an observation, it queries the LLM for expert log probabilities and adds them to the original q_values so that it can choose a better action.

My problem is that when SB3 saves the best-model checkpoint, it tries to save the external LLM along with it, which leads to an error. Is there a way to stop this part of my customized DQN policy from being saved in the checkpoint? I tried to override save() in DQN, but it still gives me the same error.

Here is my full code:

import numpy as np
import torch as th
from stable_baselines3 import DQN
import gymnasium as gym
import gymnasium.spaces as spaces
from typing import Union, Optional, Tuple, Dict
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from gpt_api import llm_prob, llm_prob_parallel
from vllm import LLM, SamplingParams
from expert_i2p import LMExpert

TensorDict = Dict[str, th.Tensor]
PyTorchObs = Union[th.Tensor, TensorDict]

class LangQNetwork(QNetwork):
    def __init__(self, *args, **kwargs):
        self.vlm_model = kwargs.pop("vlm_model", None)
        self.vlm_model_name = kwargs.pop("vlm_model_name",None)
        super().__init__(*args, **kwargs)
        # self.expert_policy = ExpertPolicy()
        print("LangQNetwork expert policy initialized:")
        print(self.vlm_model)

    @th.no_grad()
    def query_expert(self, screen_image: np.ndarray, device: str, vlm_model, vlm_model_name) -> th.Tensor:
        p = LMExpert(vlm_model,vlm_model_name,screen_image)
        p = th.tensor(p).to(device)
        return p

    def _predict(self, observation: PyTorchObs, deterministic: bool = True, screen_image: Optional[np.ndarray] = None, use_expert: bool = False) -> th.Tensor:
        q_values = self(observation) # observation and q_values device: cuda
        if use_expert:
            log_p = self.query_expert(screen_image, q_values.device,self.vlm_model,self.vlm_model_name)
            log_p = th.tensor(log_p)
            print("Log probabilities: ", log_p)
            q_values = q_values + log_p
            print("Modified Q-values: ", q_values)
        action = q_values.argmax(dim=1).reshape(-1)
        return action

class CustomDQNPolicy(DQNPolicy):
    q_net: LangQNetwork
    q_net_target: LangQNetwork

    def __init__(self, *args, **kwargs):
        vlm_model = kwargs.pop("vlm_model", None)
        if vlm_model is not None:
            print("VLM model is",vlm_model)
        else:
            print("VLM is not passed")
        self.vlm_model = vlm_model
        self.vlm_model_name = kwargs.pop("vlm_model_name",None)
        super().__init__(*args, **kwargs)

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        screen_image: Optional[np.ndarray] = None,
        use_expert: bool = False
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        # Switch to eval mode (this affects batch norm / dropout)

        # Check for common mistake that the user does not mix Gym/VecEnv API
        # Tuple obs are not supported by SB3, so we can safely do that check
        if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
            raise ValueError(
                "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
                "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
                "vs `obs = vec_env.reset()` (SB3 VecEnv). "
                "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
                "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
            )

        obs_tensor, vectorized_env = self.obs_to_tensor(observation)

        with th.no_grad():
            actions = self._predict(obs_tensor, deterministic=deterministic, screen_image=screen_image,use_expert=use_expert)
        # Convert to numpy, and reshape to the original action shape
        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[misc]

        if isinstance(self.action_space, spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)  # type: ignore[assignment, arg-type]
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]

        # Remove batch dimension if needed
        if not vectorized_env:
            assert isinstance(actions, np.ndarray)
            actions = actions.squeeze(axis=0)

        return actions, state  # type: ignore[return-value]

    def make_q_net(self) -> LangQNetwork:
        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
        net_args["vlm_model"] = self.vlm_model      # input vlm_model to q network
        net_args["vlm_model_name"] = self.vlm_model_name
        return LangQNetwork(**net_args).to(self.device)
    def _predict(self, obs: PyTorchObs, deterministic: bool = True, screen_image: Optional[np.ndarray] = None, use_expert: bool = False) -> th.Tensor:
        return self.q_net._predict(obs, deterministic=deterministic, screen_image=screen_image, use_expert=use_expert)

class CustomDQN(DQN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_expert = False
        self.collect_count = 0

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            image = self.env.get_images()
            action, state = self.policy.predict(observation, state, episode_start, deterministic, screen_image=image, use_expert=self.use_expert)
        return action, state

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param replay_buffer:
        :param log_interval: Log data every ``log_interval`` episodes
        :return:
        """
        # Switch to eval mode (this affects batch norm / dropout)

        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True
        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # turn on the expert policy before collecting data 
            if self.collect_count % 5 == 0:
                self.use_expert = True
            else:
                self.use_expert = False
            actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
            self.collect_count += 1

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            self.use_expert = False    # turn off the expert policy before eval_callback
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if not callback.on_step():
                return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)  # type: ignore[arg-type]

            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is done at the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if log_interval is not None and self._episode_num % log_interval == 0:
                        self._dump_logs()
        callback.on_rollout_end()

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

    def save(self, path, exclude=None, include=None):
        # Temporarily remove the LLM from the policy's attributes to avoid pickling it
        if hasattr(self.policy, 'vlm_model'):
            vlm_model_backup = self.policy.vlm_model
            del self.policy.vlm_model  # Remove the LLM model

        # Call the original save method
        super().save(path, exclude=exclude, include=include)

        # Reassign the LLM model after saving
        if 'vlm_model_backup' in locals():
            self.policy.vlm_model = vlm_model_backup

def env_creator5():
    env = gym.make('LunarLander-v2',render_mode='rgb_array')
    return env

if __name__=="__main__":
    env = make_vec_env(lambda: env_creator5(), n_envs=2, vec_env_cls=SubprocVecEnv,seed=0)

    eval_env = make_vec_env(lambda: env_creator5(), n_envs=5, vec_env_cls=SubprocVecEnv,seed=0)

    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path='logs/llmguide_dqn/best_model',
        log_path='logs/llmguide_dqn/results',
        eval_freq=1000,  # Evaluate every n_env*eval_freq steps
        n_eval_episodes=5,  # Number of episodes for evaluation
        deterministic=True,
        render=False
    )
    vlm_model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    download_dir = "/scratch/baz7dy/.cache"
    VLM = LLM(
        model=vlm_model_name,
        max_model_len=4096,
        max_num_seqs=10,
        enforce_eager=True,
        tensor_parallel_size = 4,
        download_dir=download_dir,
        dtype="float16",
    )

    model = CustomDQN(CustomDQNPolicy, env,verbose=1,tensorboard_log="logs/tb/llmguide_dqn",learning_starts=10000,
            policy_kwargs={"vlm_model":VLM,"vlm_model_name":vlm_model_name})

    model.learn(total_timesteps=1e7,callback=eval_callback)

    model.save("llmguide_dqn")

The key part is the query_expert function in the customized QNetwork:

class LangQNetwork(QNetwork):
    def __init__(self, *args, **kwargs):
        self.vlm_model = kwargs.pop("vlm_model", None)
        self.vlm_model_name = kwargs.pop("vlm_model_name",None)
        super().__init__(*args, **kwargs)
        # self.expert_policy = ExpertPolicy()
        print("LangQNetwork expert policy initialized:")
        print(self.vlm_model)

    @th.no_grad()
    def query_expert(self, screen_image: np.ndarray, device: str, vlm_model, vlm_model_name) -> th.Tensor:
        p = LMExpert(vlm_model,vlm_model_name,screen_image)
        p = th.tensor(p).to(device)
        return p

    def _predict(self, observation: PyTorchObs, deterministic: bool = True, screen_image: Optional[np.ndarray] = None, use_expert: bool = False) -> th.Tensor:
        q_values = self(observation) # observation and q_values device: cuda
        if use_expert:
            log_p = self.query_expert(screen_image, q_values.device,self.vlm_model,self.vlm_model_name)
            log_p = th.tensor(log_p)
            print("Log probabilities: ", log_p)
            q_values = q_values + log_p
            print("Modified Q-values: ", q_values)
        action = q_values.argmax(dim=1).reshape(-1)
        return action

Could anyone give me some suggestions on how to avoid saving the query_expert() part when saving a checkpoint?

To Reproduce

No response

Relevant log output / Error message

[rank0]: Traceback (most recent call last):
[rank0]:   File "/sfs/weka/scratch/baz7dy/INTP/lunar_trails/sb3_dqn2.py", line 349, in <module>
[rank0]:     model.learn(total_timesteps=1e7,callback=eval_callback)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/dqn/dqn.py", line 267, in learn
[rank0]:     return super().learn(
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 328, in learn
[rank0]:     rollout = self.collect_rollouts(
[rank0]:   File "/sfs/weka/scratch/baz7dy/INTP/lunar_trails/sb3_dqn2.py", line 264, in collect_rollouts
[rank0]:     if not callback.on_step():
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/callbacks.py", line 114, in on_step
[rank0]:     return self._on_step()
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/callbacks.py", line 517, in _on_step
[rank0]:     self.model.save(os.path.join(self.best_model_save_path, "best_model"))
[rank0]:   File "/sfs/weka/scratch/baz7dy/INTP/lunar_trails/sb3_dqn2.py", line 305, in save
[rank0]:     super().save(path, exclude=exclude, include=include)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/base_class.py", line 844, in save
[rank0]:     save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/save_util.py", line 316, in save_to_zip_file
[rank0]:     serialized_data = data_to_json(data)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/stable_baselines3/common/save_util.py", line 100, in data_to_json
[rank0]:     base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
[rank0]:     cp.dump(obj)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
[rank0]:     return super().dump(obj)
[rank0]:   File "/home/baz7dy/.conda/envs/ipenv1/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 577, in __reduce__
[rank0]:     raise RuntimeError("LLMEngine should not be pickled!")
[rank0]: RuntimeError: LLMEngine should not be pickled!

System Info

No response


araffin commented 1 month ago

https://github.com/DLR-RM/stable-baselines3/blob/56c153f048f1035f239b77d1569b240ace83c130/stable_baselines3/dqn/dqn.py#L276
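For reference, that line is DQN's own _excluded_save_params override, which is the hook to extend here; the custom save() above still fails because the VLM is also stored in self.policy_kwargs, which ends up in the pickled data dict. Paraphrased (not copied verbatim), the linked method looks roughly like this:

def _excluded_save_params(self) -> List[str]:
    # DQN already uses this hook to keep its torch modules out of the pickled
    # data; the same mechanism can exclude any attribute that cannot be pickled.
    return super()._excluded_save_params() + ["q_net", "q_net_target"]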

chrisgao99 commented 1 month ago

Thank you so much for the guidance. Now I know how to exclude the external LLM model when saving the model.

For anyone else who runs into this:

I pass the VLM model to the DQN class through policy_kwargs:

model = CustomDQN(CustomDQNPolicy, env,verbose=1,tensorboard_log="logs/tb/llmguide_dqn",learning_starts=10000,
            policy_kwargs={"vlm_model":VLM,"vlm_model_name":vlm_model_name})

To avoid saving the VLM, I added 'policy_kwargs' to the list returned by the _excluded_save_params function:

class CustomDQN(DQN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def _excluded_save_params(self):
        return super()._excluded_save_params() + ['policy_kwargs']
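
One caveat, as an assumption on my part rather than something tested here: because policy_kwargs is now excluded from the checkpoint, loading the saved model later will probably need those kwargs supplied again, for example through the custom_objects argument of load(). A minimal sketch, reusing the env, VLM and vlm_model_name objects from the script above:

# Hypothetical reload: re-supply the excluded policy_kwargs so the policy can be
# rebuilt; the vLLM engine is recreated fresh instead of being read from the zip.
model = CustomDQN.load(
    "logs/llmguide_dqn/best_model/best_model",
    env=env,
    custom_objects={"policy_kwargs": {"vlm_model": VLM, "vlm_model_name": vlm_model_name}},
)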

If you are not sure which parts you need to exclude, you can also print everything that will be saved right before the save_to_zip_file call here: https://github.com/DLR-RM/stable-baselines3/blob/56c153f048f1035f239b77d1569b240ace83c130/stable_baselines3/common/base_class.py#L866
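
Alternatively, a small helper along these lines (hypothetical; it only approximates the data dict that BaseAlgorithm.save() builds from self.__dict__) can report the unpicklable entries before save() is ever called:

import cloudpickle

def find_unpicklable_params(model):
    # Mirror roughly what BaseAlgorithm.save() pickles: every attribute in
    # model.__dict__ that is not listed in _excluded_save_params().
    excluded = set(model._excluded_save_params())
    for key, value in model.__dict__.items():
        if key in excluded:
            continue
        try:
            cloudpickle.dumps(value)
        except Exception as exc:
            print(f"{key}: not picklable ({type(exc).__name__}: {exc})")

find_unpicklable_params(model)  # e.g. flags policy_kwargs when it holds the vLLM engine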