TGW795 closed this issue 1 year ago.
Hi,
Can you share your code?
Sure. Here is my code with the changes I made. (I've only been checking the behavior of several observation values so far, so I haven't written the iteration part yet.)
- sumo_rl/environment/observations.py
from abc import abstractmethod

import numpy as np
from gymnasium import spaces

from .traffic_signal import TrafficSignal


class ObservationFunction:
    """Abstract base class for observation functions."""

    def __init__(self, ts: TrafficSignal):
        """Initialize observation function."""
        self.ts = ts

    @abstractmethod
    def __call__(self):
        """Subclasses must override this method."""
        # Returns only the one-hot encoded green phase.
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]
        observation = np.array(phase_id, dtype=np.float32)
        return observation

    @abstractmethod
    def observation_space(self):
        """Subclasses must override this method."""
        # One entry in [0, 1] per green phase.
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases, dtype=np.float32),
            high=np.ones(self.ts.num_green_phases, dtype=np.float32),
        )
class DefaultObservationFunction(ObservationFunction):
    """Default observation function for traffic signals."""

    def __init__(self, ts: TrafficSignal):
        """Initialize default observation function."""
        super().__init__(ts)

    def __call__(self) -> np.ndarray:
        """Return the default observation."""
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]  # one-hot encoding
        min_green = [0 if self.ts.time_since_last_phase_change < self.ts.min_green + self.ts.yellow_time else 1]
        density = self.ts.get_lanes_density()
        queue = self.ts.get_lanes_queue()
        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
        return observation

    def observation_space(self) -> spaces.Box:
        """Return the observation space."""
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases + 1 + 2 * len(self.ts.lanes), dtype=np.float32),
            high=np.ones(self.ts.num_green_phases + 1 + 2 * len(self.ts.lanes), dtype=np.float32),
        )
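For reference, the vector layout that DefaultObservationFunction builds can be reproduced without running SUMO. The sketch below is pure Python; `FakeSignal` is a hypothetical stand-in for TrafficSignal that carries only the attributes the method reads, and the density and queue numbers are made up for illustration. It shows the concatenation order (phase one-hot, min-green flag, per-lane density, per-lane queue) and checks that the length matches the `num_green_phases + 1 + 2 * len(lanes)` size declared in `observation_space`.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeSignal:
    """Hypothetical stand-in for TrafficSignal (not part of sumo-rl)."""
    green_phase: int
    num_green_phases: int
    time_since_last_phase_change: float
    min_green: float
    yellow_time: float
    lanes: List[str] = field(default_factory=list)
    densities: List[float] = field(default_factory=list)
    queues: List[float] = field(default_factory=list)

    def get_lanes_density(self) -> List[float]:
        return list(self.densities)

    def get_lanes_queue(self) -> List[float]:
        return list(self.queues)


def default_observation(ts: FakeSignal) -> List[float]:
    """Mirror the concatenation order used in DefaultObservationFunction.__call__."""
    # One-hot encoding of the current green phase.
    phase_id = [1 if ts.green_phase == i else 0 for i in range(ts.num_green_phases)]
    # 1 once min_green + yellow_time has elapsed since the last phase change, else 0.
    min_green = [0 if ts.time_since_last_phase_change < ts.min_green + ts.yellow_time else 1]
    values = phase_id + min_green + ts.get_lanes_density() + ts.get_lanes_queue()
    return [float(x) for x in values]


def expected_size(ts: FakeSignal) -> int:
    """Length declared by observation_space: phases + 1 flag + density and queue per lane."""
    return ts.num_green_phases + 1 + 2 * len(ts.lanes)


ts = FakeSignal(
    green_phase=1, num_green_phases=4, time_since_last_phase_change=12.0,
    min_green=5.0, yellow_time=2.0, lanes=["n_0", "s_0"],
    densities=[0.25, 0.5], queues=[0.0, 0.125],
)
obs = default_observation(ts)
print(obs)  # [0.0, 1.0, 0.0, 0.0, 1.0, 0.25, 0.5, 0.0, 0.125]
print(len(obs) == expected_size(ts))  # True
```

The Box bounds of zeros and ones in `observation_space` assume the density and queue values are already in [0, 1].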
- sumo_rl/environment/env.py (line 99)
observation_class: ObservationFunction = ObservationFunction,
I expected these modifications to make the observation contain only phase_id, but I actually got the values defined in DefaultObservationFunction. (I checked this by running the same code as the PettingZoo Multi-Agent API example.)
ObservationFunction is an abstract class, so you should not modify it. Instead, create a new class that implements its abstract methods:
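import numpy as np
from gymnasium import spaces

from sumo_rl.environment.observations import ObservationFunction
from sumo_rl.environment.traffic_signal import TrafficSignal


class MyObservationFunction(ObservationFunction):
    def __init__(self, ts: TrafficSignal):
        self.ts = ts

    def __call__(self):
        # One-hot encoding of the current green phase only.
        phase_id = [1 if self.ts.green_phase == i else 0 for i in range(self.ts.num_green_phases)]
        observation = np.array(phase_id, dtype=np.float32)
        return observation

    def observation_space(self):
        return spaces.Box(
            low=np.zeros(self.ts.num_green_phases, dtype=np.float32),
            high=np.ones(self.ts.num_green_phases, dtype=np.float32),
        )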
# In your experiment file:
env = sumo_rl.env(..., observation_class=MyObservationFunction)
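A related detail: in the code posted earlier in this thread, ObservationFunction is declared as a plain class (`class ObservationFunction:`), and Python only enforces `@abstractmethod` when the class's metaclass is `ABCMeta`, for example by inheriting from `abc.ABC`. The stdlib-only sketch below uses hypothetical names (`StubSignal`, `BaseObservation`, `PhaseOnlyObservation`) to show the subclass-and-implement pattern and what enforcement looks like; it is an illustration, not sumo-rl code, and `observation_space` returns plain `(low, high)` lists instead of a gymnasium Box.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StubSignal:
    """Hypothetical stand-in for a traffic signal (not sumo-rl's TrafficSignal)."""
    green_phase: int
    num_green_phases: int


class BaseObservation(ABC):
    """Abstract base: subclasses must implement both methods before instantiation."""

    def __init__(self, ts: StubSignal):
        self.ts = ts

    @abstractmethod
    def __call__(self) -> List[float]:
        """Return the observation vector."""

    @abstractmethod
    def observation_space(self) -> Tuple[List[float], List[float]]:
        """Return (low, high) bounds; a stand-in for a gymnasium Box."""


class PhaseOnlyObservation(BaseObservation):
    """Concrete subclass: one-hot encoding of the current green phase only."""

    def __call__(self) -> List[float]:
        n = self.ts.num_green_phases
        return [1.0 if self.ts.green_phase == i else 0.0 for i in range(n)]

    def observation_space(self) -> Tuple[List[float], List[float]]:
        n = self.ts.num_green_phases
        return [0.0] * n, [1.0] * n


ts = StubSignal(green_phase=2, num_green_phases=3)
print(PhaseOnlyObservation(ts)())  # [0.0, 0.0, 1.0]

try:
    BaseObservation(ts)  # abstract methods are not implemented, so this is refused
except TypeError as err:
    print("TypeError:", err)
```

Without `ABC` in the base class list, `BaseObservation(ts)` would be instantiated without complaint.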
It worked! Thank you! This was just an elementary mistake on my part :)
Hi.
I've been trying to write code to experiment with MARL using sumo_rl.parallel_env() and my own ObservationFunction, but none of the changes I make take effect. The README states that we can use our own ObservationFunction by defining it in observations.py and passing it to the environment constructor, and I followed that approach. Is there anything I am doing wrong? (I'm using the PettingZoo Multi-Agent API example code.) Thank you.
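The workflow described in this question passes a class, not an instance, to the environment constructor. The sketch below illustrates that general pattern with hypothetical names (`Signal`, `DefaultObs`, `PhaseOnlyObs`, `MiniEnv`); it assumes, rather than reproduces, how sumo-rl wires `observation_class` internally, namely that the environment builds one observation object per traffic signal from whatever class it receives. What it shows is that an explicitly passed class overrides the constructor's default, so the default itself never needs editing.

```python
from typing import Dict, List


def one_hot(index: int, size: int) -> List[float]:
    """Return a list of length `size` with 1.0 at `index` and 0.0 elsewhere."""
    return [1.0 if i == index else 0.0 for i in range(size)]


class Signal:
    """Hypothetical traffic signal holding only the fields used below."""

    def __init__(self, ts_id: str, green_phase: int, num_green_phases: int):
        self.id = ts_id
        self.green_phase = green_phase
        self.num_green_phases = num_green_phases


class DefaultObs:
    """Hypothetical default: phase one-hot plus one placeholder feature."""

    def __init__(self, ts: Signal):
        self.ts = ts

    def __call__(self) -> List[float]:
        return one_hot(self.ts.green_phase, self.ts.num_green_phases) + [0.5]


class PhaseOnlyObs:
    """Hypothetical user-defined observation: phase one-hot only."""

    def __init__(self, ts: Signal):
        self.ts = ts

    def __call__(self) -> List[float]:
        return one_hot(self.ts.green_phase, self.ts.num_green_phases)


class MiniEnv:
    """Hypothetical environment that receives the observation *class* and instantiates it."""

    def __init__(self, signals: List[Signal], observation_class: type = DefaultObs):
        # One observation object per signal, built from whichever class was passed in.
        self.observation_fns = {ts.id: observation_class(ts) for ts in signals}

    def observe(self) -> Dict[str, List[float]]:
        return {ts_id: fn() for ts_id, fn in self.observation_fns.items()}


signals = [Signal("A", 0, 2), Signal("B", 1, 2)]
print(MiniEnv(signals).observe())
# {'A': [1.0, 0.0, 0.5], 'B': [0.0, 1.0, 0.5]}
print(MiniEnv(signals, observation_class=PhaseOnlyObs).observe())
# {'A': [1.0, 0.0], 'B': [0.0, 1.0]}
```

Injecting the class keeps user-specific observation logic in the experiment script rather than in the library's source files.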