LucasAlegre / sumo-rl

Reinforcement Learning environments for Traffic Signal Control with SUMO. Compatible with Gymnasium, PettingZoo, and popular RL libraries.
https://lucasalegre.github.io/sumo-rl
MIT License

Cannot load DQN model #195

Closed fqidz closed 6 months ago

fqidz commented 6 months ago

First of all, sorry if this doesn't belong here. If so, I'll post it on the stable-baselines3 GitHub instead.

Hello, I'm a beginner and I can't load a saved DQN model. I trained it using libsumo, and I want to load the saved model so that I can see its performance using sumo-gui. I'm not sure whether this is an issue with the environment or with stable-baselines3.

When I try to load the model using model = DQN.load('./output/model_saved.zip', env=env), I get the error:

/home/<name>/Documents/sumo-traffic-capstone/utils/stable-baselines3/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
/home/<name>/Documents/sumo-traffic-capstone/utils/stable-baselines3/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
 Retrying in 1 seconds
 ....

Here's the print_system_info if it's relevant:

== CURRENT SYSTEM INFO ==
- OS: Linux-6.7.10-200.fc39.x86_64-x86_64-with-glibc2.38 # 1 SMP PREEMPT_DYNAMIC Mon Mar 18 18:56:52 UTC 2024
- Python: 3.12.2
- Stable-Baselines3: 2.3.0a5
- PyTorch: 2.2.1+cu121
- GPU Enabled: False
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

== SAVED MODEL SYSTEM INFO ==
- OS: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35 # 1 SMP Thu Oct 5 21:02:42 UTC 2023
- Python: 3.10.12
- Stable-Baselines3: 2.3.0a5
- PyTorch: 2.2.1+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1

I tried passing custom_objects, model = DQN.load('./output/model_saved.zip', env=env, custom_objects={'lr_schedule': 0.0, 'exploration_schedule': 0.0}), as per the example here. That removed the warnings, but it still didn't load the saved model and started learning from the beginning.
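A note on the warnings above: the system info shows the model was saved under Python 3.10 and loaded under 3.12. Pickled functions such as the schedules embed version-specific bytecode, which is a likely cause of the "code() argument 13" failure. As far as I know, the stable-baselines3 custom_objects example passes callables for schedule entries rather than bare floats. A minimal sketch, assuming constant schedules are acceptable; the helper name constant_schedule and both values are hypothetical, not taken from the saved model:

```python
from typing import Callable


def constant_schedule(value: float) -> Callable[[float], float]:
    """Return a schedule that ignores training progress and always yields `value`.

    Hypothetical helper: SB3 schedules are callables that receive
    `progress_remaining` (1.0 at the start of training, 0.0 at the end).
    """
    def schedule(progress_remaining: float) -> float:
        return value

    return schedule


# Placeholder values; pick the ones that match how the model was trained.
custom_objects = {
    "lr_schedule": constant_schedule(1e-4),
    "exploration_schedule": constant_schedule(0.05),
}

# Intended use (not executed here, needs stable-baselines3 and the env):
# model = DQN.load('./output/model_saved.zip', env=env, custom_objects=custom_objects)
```

Replacing the schedules only sidesteps the deserialization warnings; the network weights are stored separately in the zip file and are not affected by these entries.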

I'm using this as the env:

env = SumoEnvironment(net_file='./sumo-things/net.net.xml',
                      route_file='./sumo-things/main.rou.xml',
                      out_csv_name='./output/dqn-stats/traffic_sim',
                      reward_fn=my_reward_fn,
                      yellow_time=4,
                      time_to_teleport=2000,
                      use_gui=use_gui,
                      single_agent=True,
                      num_seconds=num_seconds,
                      )

Another thing to note: I cloned this repo and ran pip install -e . because I wanted to edit one part of env.py to output the queue length in the CSV file. I'm not sure if it's relevant:

    def _get_per_agent_info(self):
        stopped = [self.traffic_signals[ts].get_total_queued()
                   for ts in self.ts_ids]
        accumulated_waiting_time = [
            sum(self.traffic_signals[ts].get_accumulated_waiting_time_per_lane()) for ts in self.ts_ids
        ]
        average_speed = [self.traffic_signals[ts].get_average_speed()
                         for ts in self.ts_ids]
+       total_queued = [self.traffic_signals[ts].get_total_queued()
+                       for ts in self.ts_ids]
        info = {}
        for i, ts in enumerate(self.ts_ids):
            info[f"{ts}_stopped"] = stopped[i]
            info[f"{ts}_accumulated_waiting_time"] = accumulated_waiting_time[i]
            info[f"{ts}_average_speed"] = average_speed[i]
+           info[f"{ts}_queue_length"] = total_queued[i]
        info["agents_total_stopped"] = sum(stopped)
        info["agents_total_accumulated_waiting_time"] = sum(
            accumulated_waiting_time)
        return info

fqidz commented 6 months ago

Solved this with the help of this reply

I added model.set_env(env) after DQN.load(), and passed reset_num_timesteps=False to model.learn().
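For anyone landing here later, the fix can be sketched roughly as below. This is an illustration under assumptions, not the exact code used: the function name resume_training and its parameters are hypothetical, and algo_cls is expected to be a stable-baselines3 algorithm class such as DQN.

```python
def resume_training(algo_cls, model_path, env, extra_timesteps, custom_objects=None):
    """Load a saved model, attach an environment, and continue training.

    Hypothetical helper. `algo_cls` is assumed to expose the SB3 interface:
    a `load` classmethod plus `set_env` and `learn` on the returned model.
    """
    # `load` is a classmethod that returns a new model, so keep its return value.
    model = algo_cls.load(model_path, custom_objects=custom_objects)

    # Attach the environment explicitly after loading.
    model.set_env(env)

    # reset_num_timesteps=False keeps the saved timestep counter, so
    # progress-dependent schedules continue instead of restarting from zero.
    model.learn(total_timesteps=extra_timesteps, reset_num_timesteps=False)
    return model


# Intended use (not executed here, needs stable-baselines3 and the SUMO env):
# from stable_baselines3 import DQN
# model = resume_training(DQN, './output/model_saved.zip', env, 10_000)
```

Taking the algorithm class as a parameter keeps the sketch independent of DQN, so the same three steps would apply to other stable-baselines3 algorithms.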