-
Notifications
You must be signed in to change notification settings - Fork 929
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement VideoViz to record model runs in a video
- Loading branch information
Showing
3 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Example of using VideoViz with the Schelling model.""" | ||
|
||
from mesa.examples.basic.schelling.model import Schelling | ||
from mesa.visualization.video_viz import ( | ||
VideoViz, | ||
make_measure_component, | ||
make_space_component, | ||
) | ||
|
||
# Create model | ||
model = Schelling(10, 10) | ||
|
||
|
||
def agent_portrayal(agent): | ||
"""Portray agents based on their type.""" | ||
if agent is None: | ||
return {} | ||
|
||
portrayal = { | ||
"color": "red" if agent.type == 0 else "blue", | ||
"size": 25, | ||
"marker": "s", # square marker | ||
} | ||
return portrayal | ||
|
||
|
||
# Create visualization with space and some metrics | ||
viz = VideoViz( | ||
model, | ||
[ | ||
make_space_component(agent_portrayal=agent_portrayal, save_format="svg"), | ||
make_measure_component("happy", save_format="svg"), | ||
], | ||
title="Schelling's Segregation Model", | ||
) | ||
|
||
# Record simulation | ||
if __name__ == "__main__": | ||
video_path = viz.record(steps=50, filepath="schelling.mp4") | ||
print(f"Video saved to: {video_path}") | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
"""Video recording components for Mesa model visualization.""" | ||
|
||
import shutil | ||
from collections.abc import Callable, Sequence | ||
from pathlib import Path | ||
|
||
import matplotlib.animation as animation | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import mesa | ||
from mesa.visualization.matplotlib_renderer import ( | ||
MatplotlibRenderer, | ||
MeasureRendererMatplotlib, | ||
SpaceRenderMatplotlib, | ||
) | ||
|
||
|
||
def make_space_component( | ||
agent_portrayal: Callable | None = None, | ||
propertylayer_portrayal: dict | None = None, | ||
post_process: Callable | None = None, | ||
**space_drawing_kwargs, | ||
): | ||
"""Create a Matplotlib-based space visualization component. | ||
Args: | ||
agent_portrayal: Function to portray agents. | ||
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications | ||
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks) | ||
backend: The backend to use for rendering the space. Can be "matplotlib" or "altair". | ||
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See | ||
the functions for drawing the various spaces for further details. | ||
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", | ||
"size", "marker", and "zorder". Other field are ignored and will result in a user warning. | ||
Returns: | ||
SpaceRenderMatplotlib: A component for rendering the space. | ||
""" | ||
if agent_portrayal is None: | ||
|
||
def agent_portrayal(a): | ||
return {} | ||
|
||
return SpaceRenderMatplotlib( | ||
agent_portrayal, | ||
propertylayer_portrayal, | ||
post_process=post_process, | ||
**space_drawing_kwargs, | ||
) | ||
|
||
|
||
def make_measure_component( | ||
measure: Callable, | ||
**kwargs, | ||
): | ||
"""Create a plotting function for a specified measure. | ||
Args: | ||
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. | ||
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor. | ||
Returns: | ||
MeasureRendererMatplotlib: A component for rendering the measure. | ||
""" | ||
return MeasureRendererMatplotlib( | ||
measure, | ||
**kwargs, | ||
) | ||
|
||
|
||
class VideoViz: | ||
"""Create high-quality video recordings of model simulations.""" | ||
|
||
def __init__( | ||
self, | ||
model: mesa.Model, | ||
components: Sequence[MatplotlibRenderer], | ||
*, | ||
title: str | None = None, | ||
figsize: tuple[float, float] | None = None, | ||
grid: tuple[int, int] | None = None, | ||
): | ||
"""Initialize video visualization configuration. | ||
Args: | ||
model: The model to simulate and record | ||
components: Sequence of component objects defining what to visualize | ||
title: Optional title for the video | ||
figsize: Optional figure size in inches (width, height) | ||
grid: Optional (rows, cols) for custom layout. Auto-calculated if None. | ||
""" | ||
# Check if FFmpeg is available | ||
if not shutil.which("ffmpeg"): | ||
raise RuntimeError( | ||
"FFmpeg not found. Please install FFmpeg to save animations:\n" | ||
" - macOS: brew install ffmpeg\n" | ||
" - Linux: sudo apt-get install ffmpeg\n" | ||
" - Windows: download from https://ffmpeg.org/download.html" | ||
) | ||
self.model = model | ||
self.components = components | ||
self.title = title | ||
self.figsize = figsize | ||
self.grid = grid or self._calculate_grid(len(components)) | ||
|
||
# Setup figure and axes | ||
self.fig, self.axes = self._setup_figure() | ||
|
||
def record( | ||
self, | ||
*, | ||
steps: int, | ||
filepath: str | Path, | ||
dpi: int = 100, | ||
fps: int = 10, | ||
codec: str = "h264", | ||
bitrate: int = 2000, | ||
) -> Path: | ||
"""Record model simulation to video file. | ||
Args: | ||
steps: Number of simulation steps to record | ||
filepath: Where to save the video file | ||
dpi: Resolution of the output video | ||
fps: Frames per second in the output video | ||
codec: Video codec to use | ||
bitrate: Video bitrate in kbps (default: 2000) | ||
Returns: | ||
Path to the saved video file | ||
Raises: | ||
RuntimeError: If FFmpeg is not installed | ||
""" | ||
filepath = Path(filepath) | ||
|
||
def update(frame_num): | ||
# Update model state | ||
self.model.step() | ||
|
||
# Render all visualization frames | ||
for component, ax in zip(self.components, self.axes): | ||
ax.clear() | ||
component.draw(self.model, ax) | ||
return self.axes | ||
|
||
# Create and save animation | ||
anim = animation.FuncAnimation( | ||
self.fig, update, frames=steps, interval=1000 / fps, blit=False | ||
) | ||
|
||
writer = animation.FFMpegWriter( | ||
fps=fps, | ||
codec=codec, | ||
bitrate=bitrate, # Now passing as integer | ||
) | ||
|
||
anim.save(filepath, writer=writer, dpi=dpi) | ||
return filepath | ||
|
||
def _calculate_grid(self, n_frames: int) -> tuple[int, int]: | ||
"""Calculate optimal grid layout for given number of frames.""" | ||
cols = min(3, n_frames) # Max 3 columns | ||
rows = int(np.ceil(n_frames / cols)) | ||
return (rows, cols) | ||
|
||
def _setup_figure(self): | ||
"""Setup matplotlib figure and axes.""" | ||
if not self.figsize: | ||
self.figsize = (5 * self.grid[1], 5 * self.grid[0]) | ||
fig = plt.figure(figsize=self.figsize) | ||
axes = [] | ||
|
||
for i in range(len(self.components)): | ||
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1) | ||
axes.append(ax) | ||
|
||
if self.title: | ||
fig.suptitle(self.title, fontsize=16) | ||
fig.tight_layout() | ||
return fig, axes |