Open — Chan-Dong-Jun opened this issue 2 months ago
Caching methods
def get_grid_dataframe(self, cache_file_path: str = None):
grid_state = {
'width': self.model.grid.width,
'height': self.model.grid.height,
'agents': []
}
for x in range(grid_state['width']):
for y in range(grid_state['height']):
cell_contents = self.model.grid._grid[x][y]
if cell_contents:
if not hasattr(cell_contents, "__iter__"):
cell_contents = [cell_contents]
for agent in cell_contents:
agent_state = {
'pos_x': agent.pos[0],
'pos_y': agent.pos[1],
'unique_id': agent.unique_id,
'wealth': agent.wealth,
# **agent.__dict__
}
grid_state['agents'].append(agent_state)
padding = len(str(self._total_steps)) - 1
filename = f"{self.cache_file_path}/grid_data_{(self.model._steps):0{padding}}.parquet"
# Convert to DataFrame
df = pd.DataFrame(grid_state['agents'])
# Save DataFrame to Parquet
df.to_parquet(filename)
@staticmethod
def reconstruct_grid(filename, *attributes_list):
    """Rebuild a ``Grid`` populated with agents from a cached Parquet snapshot.

    Args:
        filename: Path to a Parquet file with at least the columns
            ``pos_x``, ``pos_y``, ``unique_id`` and ``wealth``.
        *attributes_list: Names of additional columns to copy onto each
            reconstructed agent as attributes of the same name.

    Returns:
        A ``Grid(width, height, False)`` with one ``Agent`` placed per row.
        Width and height are inferred from the largest occupied
        coordinates.

    Raises:
        ValueError: If the snapshot contains no agents. The grid dimensions
            cannot be inferred from an empty file.
    """
    df = pd.read_parquet(filename)
    if df.empty:
        raise ValueError(f"cannot infer grid size from empty snapshot {filename!r}")

    # The snapshot does not store the grid size, so infer it from the largest
    # occupied coordinate (positions assumed to start at 0).
    # NOTE(review): if the last column/row is empty, the rebuilt grid is
    # smaller than the original one.
    width = int(df["pos_x"].max()) + 1
    height = int(df["pos_y"].max()) + 1
    grid = Grid(width, height, False)

    # A single placeholder model shared by all agents. Constructing a new
    # Model for every row was needlessly expensive.
    model = Model(100, 10, 10)

    # to_dict("records") keeps each column's own dtype. iterrows() upcasts a
    # whole row to one common dtype, so integer positions would become floats
    # as soon as any column (e.g. wealth) is a float.
    for record in df.to_dict("records"):
        agent = Agent(record["unique_id"], model)
        agent.wealth = record["wealth"]
        for attribute in attributes_list:
            setattr(agent, attribute, record[attribute])
        grid.place_agent(agent, (int(record["pos_x"]), int(record["pos_y"])))
    return grid
`get_grid_dataframe` caches the positions of the agents and writes them to a Parquet file. `reconstruct_grid` takes the Parquet file and returns a grid object populated with the agents.
Visualisation module
@solara.component
def SpaceMatplotlib(
    model,
    agent_portrayal,
    dependencies: list[any] | None = None,
    cache_dir: str = "output_dir",
    step_digits: int = 3,
):
    """Draw the model's space with Matplotlib, preferring the cached grid.

    The grid for the current step is rebuilt from
    ``<cache_dir>/grid_data_<model._steps + 1>.parquet``, with the step
    number zero-padded to ``step_digits`` digits. If that file does not
    exist, the model's live ``grid`` (or ``space``) is drawn instead.

    Args:
        model: Model whose ``_steps`` counter selects the cached snapshot.
        agent_portrayal: Passed through to the drawing helpers.
        dependencies: Forwarded to ``solara.FigureMatplotlib``.
        cache_dir: Directory containing the ``grid_data_*.parquet`` files.
        step_digits: Zero-padding width of the step number in file names.
            It must match the width used when the files were written.
    """
    space_fig = Figure()
    space_ax = space_fig.subplots()

    # NOTE(review): confirm that the ``+ 1`` offset matches the step number
    # the cache writer uses when naming files.
    filename = f"{cache_dir}/grid_data_{(model._steps + 1):0{step_digits}}.parquet"
    try:
        space = CacheableModel.reconstruct_grid(filename)
    except FileNotFoundError:
        # No cached snapshot for this step: fall back to the live model.
        # Sometimes the space is defined as model.space instead of model.grid.
        space = getattr(model, "grid", None)
        if space is None:
            space = model.space

    if isinstance(space, mesa.space.NetworkGrid):
        _draw_network_grid(space, space_ax, agent_portrayal)
    elif isinstance(space, mesa.space.ContinuousSpace):
        _draw_continuous_space(space, space_ax, agent_portrayal)
    else:
        _draw_grid(space, space_ax, agent_portrayal)
    solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)
The `SpaceMatplotlib` component reads a Parquet file and reconstructs the grid object from it, rather than reading the `grid` attribute of the model object directly. This is still a work in progress, but the grid visualisation can already read directly from the cached files.
Caching methods
class TestModel(mesa.Model):
    """Placeholder model that only advances its step counter.

    It owns no agents, grid or data collector. The visualisation is expected
    to read everything it displays from the cached Parquet files, so this
    model exists only to be stepped.

    ``N``, ``width`` and ``height`` are accepted but unused — presumably so
    it can be constructed with the same arguments as the real model (confirm
    with callers).
    """

    def __init__(self, N, width, height):
        super().__init__()

    def step(self):
        """Increment ``_steps`` by one; no agents are activated."""
        self._steps = self._steps + 1
`TestModel` is a dummy model passed to `SolaraViz`. It lets the visualiser advance step by step while the displayed data is read from the cache.
Limitations:
Caching methods
@solara.component
def PlotMatplotlib(
    model,
    measure,
    dependencies: list[any] | None = None,
    cache_dir: str = "output_dir",
):
    """Plot model-level measures read from the cached Parquet files.

    All ``<cache_dir>/model_data_*.parquet`` files are read in ascending
    step order and concatenated. The result is truncated to the first
    ``model._steps + 1`` rows, so the chart advances as the model steps.

    Args:
        model: Model whose ``_steps`` counter limits how many rows are shown.
        measure: A column name, a ``{column: color}`` dict, or a list/tuple
            of column names to plot.
        dependencies: Forwarded to ``solara.FigureMatplotlib``.
        cache_dir: Directory containing the ``model_data_*.parquet`` files.
    """
    import re

    fig = Figure()
    ax = fig.subplots()

    def _step_key(path):
        # glob.glob returns files in arbitrary order. Sort numerically by the
        # trailing step number so rows are concatenated chronologically;
        # files without a number sort first, ties broken by path.
        match = re.search(r"(\d+)\.parquet$", path)
        return (int(match.group(1)) if match else -1, path)

    model_files = sorted(glob.glob(f"{cache_dir}/model_data_*.parquet"), key=_step_key)
    model_dfs = [pq.read_table(model_file).to_pandas() for model_file in model_files]

    if not model_dfs:
        # Nothing cached yet: show an empty figure instead of letting
        # pd.concat raise "No objects to concatenate".
        solara.FigureMatplotlib(fig, dependencies=dependencies)
        return

    df = pd.concat(model_dfs, ignore_index=True)[: model._steps + 1]
    if isinstance(measure, str):
        ax.plot(df.loc[:, measure])
        ax.set_ylabel(measure)
    elif isinstance(measure, dict):
        for m, color in measure.items():
            ax.plot(df.loc[:, m], label=m, color=color)
        fig.legend()
    elif isinstance(measure, list | tuple):
        for m in measure:
            ax.plot(df.loc[:, m], label=m)
        fig.legend()
    # Set integer x axis
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    solara.FigureMatplotlib(fig, dependencies=dependencies)
`PlotMatplotlib` now plots the Matplotlib graph from the cached data.
What's the problem this feature will solve? There should be a way to cache the positions of the agents at each step of the simulation, so that the agents can be displayed in the grid visualization.
Describe the solution you'd like