Chan-Dong-Jun / mesa-forked

Mesa is an open-source Python library for agent-based modeling, ideal for simulating complex systems and exploring emergent behaviors.
https://mesa.readthedocs.io
Apache License 2.0
1 star, 0 forks, source link

Cache grid pos #4

Open Chan-Dong-Jun opened 2 months ago

Chan-Dong-Jun commented 2 months ago

What problem will this feature solve? There should be a method to cache the positions of the agents at each step of the simulation, so that the agents can be displayed in the grid visualization.

Describe the solution you'd like

  1. Cache the positions of the agents
  2. Modify the visualization module. Currently, the visualization module needs the model object in order to read its `grid` attribute. It should be changed so that the model object is not needed and the grid is constructed directly from the cached file.
Chan-Dong-Jun commented 2 months ago

Main edit

Caching methods

    def get_grid_dataframe(self, cache_file_path: str = None):
        """Snapshot every agent on ``self.model.grid`` and write it to a Parquet file.

        One row is written per agent, with the columns ``pos_x``, ``pos_y``,
        ``unique_id`` and ``wealth``. The file is named
        ``grid_data_<step>.parquet``, where ``<step>`` is ``self.model._steps``
        zero-padded to ``len(str(self._total_steps)) - 1`` digits.

        Args:
            cache_file_path: Directory to write the file into. When omitted,
                ``self.cache_file_path`` is used. (Previously this argument
                was accepted but silently ignored.)

        Returns:
            The DataFrame that was written. (Previously ``None``.)
        """
        directory = self.cache_file_path if cache_file_path is None else cache_file_path
        grid = self.model.grid

        agents = []
        for x in range(grid.width):
            for y in range(grid.height):
                cell_contents = grid._grid[x][y]
                if not cell_contents:
                    continue
                # A cell may hold a single agent rather than a collection of agents.
                if not hasattr(cell_contents, "__iter__"):
                    cell_contents = [cell_contents]
                for agent in cell_contents:
                    agents.append(
                        {
                            "pos_x": agent.pos[0],
                            "pos_y": agent.pos[1],
                            "unique_id": agent.unique_id,
                            "wealth": agent.wealth,
                        }
                    )

        # NOTE(review): the "- 1" presumably assumes steps run
        # 0 .. _total_steps - 1; any reader that hard-codes the zero-padding
        # width must agree with this value — confirm.
        padding = len(str(self._total_steps)) - 1
        filename = f"{directory}/grid_data_{self.model._steps:0{padding}}.parquet"

        # Explicit columns keep the schema stable even when the grid is empty,
        # so readers looking up e.g. "pos_x" do not hit a KeyError.
        df = pd.DataFrame(agents, columns=["pos_x", "pos_y", "unique_id", "wealth"])
        df.to_parquet(filename)
        return df

    @staticmethod
    def reconstruct_grid(filename, *attributes_list, width=None, height=None):
        """Rebuild a ``Grid`` populated with agents from a cached Parquet snapshot.

        Args:
            filename: Path of a Parquet file containing at least the columns
                ``pos_x``, ``pos_y``, ``unique_id`` and ``wealth``.
            *attributes_list: Extra column names to copy onto every rebuilt
                agent as attributes, in addition to ``wealth``. (Previously
                accepted but ignored.)
            width: Grid width. When omitted, it is inferred as
                ``max(pos_x) + 1``, which undersizes the grid if no agent
                occupies the last column.
            height: Grid height. When omitted, it is inferred as
                ``max(pos_y) + 1``, with the same caveat for the last row.

        Returns:
            A ``Grid(width, height, False)`` containing one placeholder
            ``Agent`` per row of the file.

        Raises:
            ValueError: If the file has no rows and ``width``/``height`` were
                not both supplied, so the grid size cannot be inferred.
        """
        df = pd.read_parquet(filename)

        if (width is None or height is None) and df.empty:
            raise ValueError(
                f"{filename} contains no agents; pass width and height explicitly"
            )
        # Positions are assumed to start at 0.
        if width is None:
            width = int(df["pos_x"].max()) + 1
        if height is None:
            height = int(df["pos_y"].max()) + 1
        grid = Grid(width, height, False)

        # One shared owner model instead of constructing a new Model for every
        # row. NOTE(review): Model(100, 10, 10) looks like the example model's
        # (N, width, height) signature — confirm it is only a placeholder owner.
        owner = Model(100, 10, 10)
        names = dict.fromkeys(("wealth", *attributes_list))  # ordered, de-duplicated

        # to_dict("records") keeps per-column dtypes, unlike iterrows(), which
        # upcasts a row to one dtype (e.g. float positions when wealth is a float).
        for record in df.to_dict("records"):
            agent = Agent(record["unique_id"], owner)
            for name in names:
                setattr(agent, name, record[name])
            grid.place_agent(agent, (int(record["pos_x"]), int(record["pos_y"])))

        return grid

`get_grid_dataframe` caches the positions of the agents and writes them to a Parquet file. `reconstruct_grid` takes that Parquet file and returns a grid object populated with the agents.

Visualisation module

@solara.component
def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None):
    """Draw the model's space, preferring the cached grid snapshot for the next step.

    The snapshot is read from ``output_dir/grid_data_<model._steps + 1>.parquet``,
    with the step zero-padded to 3 digits. If that file does not exist, the live
    ``model.grid`` is drawn instead, falling back to ``model.space``.

    Raises:
        FileNotFoundError: If there is no cached snapshot and the model has
            neither a ``grid`` nor a ``space`` attribute to draw.
    """
    import os

    space_fig = Figure()
    space_ax = space_fig.subplots()

    cache_file = f"output_dir/grid_data_{(model._steps + 1):03}.parquet"
    if os.path.exists(cache_file):
        space = CacheableModel.reconstruct_grid(cache_file)
    else:
        # No snapshot for this step: draw the live model's space instead.
        space = getattr(model, "grid", None)
        if space is None:
            # Sometimes the space is defined as model.space instead of model.grid
            space = getattr(model, "space", None)
        if space is None:
            raise FileNotFoundError(
                f"No cached grid at {cache_file!r} and the model has no grid/space"
            )

    if isinstance(space, mesa.space.NetworkGrid):
        _draw_network_grid(space, space_ax, agent_portrayal)
    elif isinstance(space, mesa.space.ContinuousSpace):
        _draw_continuous_space(space, space_ax, agent_portrayal)
    else:
        _draw_grid(space, space_ax, agent_portrayal)
    solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)

`SpaceMatplotlib` reads a Parquet file and reconstructs the grid object from it, instead of reading the `grid` attribute of the model object directly. This is still a work in progress, but the grid visualization can already read directly from the cached files. image

Chan-Dong-Jun commented 2 months ago

Main edit

Dummy model

    class TestModel(mesa.Model):
        """Placeholder model that only advances the step counter.

        It is passed to SolaraViz so that the visualizer can step through a run
        whose displayed data is read from cached files, not simulated live. No
        grid, scheduler, agents or data collector are created.

        Args:
            N: Accepted for signature compatibility with the real model; unused.
            width: Accepted for signature compatibility; unused.
            height: Accepted for signature compatibility; unused.
        """

        def __init__(self, N, width, height):
            super().__init__()

        def step(self):
            """Advance the step counter; no agents or data collection run here."""
            # NOTE(review): relies on mesa.Model.__init__ having initialised
            # self._steps — confirm for the mesa version in use.
            self._steps += 1

`TestModel` is a dummy model passed to SolaraViz. It lets the visualizer advance step by step while the displayed data is taken from the cache.

Limitations:

Chan-Dong-Jun commented 2 months ago

Main edit

Plotting from cached data

   @solara.component
def PlotMatplotlib(model, measure, dependencies: list[any] | None = None):
    fig = Figure()
    ax = fig.subplots()

    # TODO: Check
    model_files = glob.glob(f"output_dir/model_data_*.parquet")
    model_dfs = []
    for model_file in model_files:
        table = pq.read_table(model_file)
        df = table.to_pandas()
        model_dfs.append(df)
    df = pd.concat(model_dfs, ignore_index=True)[:model._steps+1]

    if isinstance(measure, str):
        ax.plot(df.loc[:, measure])
        ax.set_ylabel(measure)
    elif isinstance(measure, dict):
        for m, color in measure.items():
            ax.plot(df.loc[:, m], label=m, color=color)
        fig.legend()
    elif isinstance(measure, list | tuple):
        for m in measure:
            ax.plot(df.loc[:, m], label=m)
        fig.legend()
    # Set integer x axis
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    solara.FigureMatplotlib(fig, dependencies=dependencies)

`PlotMatplotlib` now plots the Matplotlib chart from the cached data.

image