-
Notifications
You must be signed in to change notification settings - Fork 874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce Altair Grid #1902
base: main
Are you sure you want to change the base?
Introduce Altair Grid #1902
Changes from 3 commits
a7a6c4b
6236e64
7707d5b
0b3b091
300e48c
9b1ba7f
9dbb494
9e0c00e
a0ce9d9
593a9f3
5182198
3fe1838
aa52006
487dca3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import json | ||
from typing import Callable | ||
|
||
import altair as alt | ||
import solara | ||
|
||
import mesa | ||
|
||
|
||
def get_agent_data_from_coord_iter(data): | ||
for agent, (x, y) in data: | ||
if agent: | ||
agent_data = json.loads( | ||
json.dumps(agent[0].__dict__, skipkeys=True, default=str) | ||
) | ||
agent_data["x"] = x | ||
agent_data["y"] = y | ||
agent_data.pop("model", None) | ||
agent_data.pop("pos", None) | ||
yield agent_data | ||
|
||
|
||
def create_grid( | ||
color: str | None = None, | ||
on_click: Callable[[mesa.Model, mesa.space.Coordinate], None] | None = None, | ||
) -> Callable[[mesa.Model], solara.component]: | ||
return lambda model: Grid(model, color, on_click) | ||
ankitk50 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def Grid(model, color=None, on_click=None): | ||
if color is None: | ||
color = "unique_id:N" | ||
|
||
if color[-2] != ":": | ||
color = color + ":N" | ||
|
||
print(model.grid.coord_iter()) | ||
|
||
data = solara.reactive( | ||
list(get_agent_data_from_coord_iter(model.grid.coord_iter())) | ||
) | ||
|
||
def update_data(): | ||
data.value = list(get_agent_data_from_coord_iter(model.grid.coord_iter())) | ||
|
||
def click_handler(datum): | ||
if datum is None: | ||
return | ||
on_click(model, datum["x"], datum["y"]) | ||
update_data() | ||
|
||
default_tooltip = [f"{key}:N" for key in data.value[0]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice that you have tooltip and click handlers - a little hard to assess without documentation, I would want to know how I can customize these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes please add docstrings to this explaining how to customize the tooltip |
||
chart = ( | ||
alt.Chart(alt.Data(values=data.value)) | ||
.mark_rect() | ||
.encode( | ||
x=alt.X("x:N", scale=alt.Scale(domain=list(range(model.grid.width)))), | ||
y=alt.Y( | ||
"y:N", | ||
scale=alt.Scale(domain=list(range(model.grid.height - 1, -1, -1))), | ||
), | ||
color=color, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In one of my models where I'm using a custom altair space drawer, I'm setting color, size, and shape. Probably reasonable not to support all of those on the first pass, but it would be good to think about a more generalized approach (like the agent portrayal method) that would make it possible to customize this without having to completely re-implement. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be pretty straightforward to do with the |
||
tooltip=default_tooltip, | ||
) | ||
.properties(width=600, height=600) | ||
) | ||
return solara.FigureAltair(chart, on_click=click_handler) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
from solara.alias import rv | ||
|
||
import mesa | ||
from mesa.experimental.altair_grid import create_grid | ||
|
||
# Avoid interactive backend | ||
plt.switch_backend("agg") | ||
|
@@ -75,6 +76,12 @@ def ColorCard(color, layout_type): | |
SpaceMatplotlib( | ||
model, agent_portrayal, dependencies=[current_step.value] | ||
) | ||
elif space_drawer == "altair": | ||
# draw with the altair implementation | ||
SpaceAltair( | ||
model, agent_portrayal, dependencies=[current_step.value] | ||
) | ||
|
||
elif space_drawer: | ||
# if specified, draw agent space with an alternate renderer | ||
space_drawer(model, agent_portrayal) | ||
|
@@ -109,6 +116,9 @@ def render_in_jupyter(): | |
SpaceMatplotlib( | ||
model, agent_portrayal, dependencies=[current_step.value] | ||
) | ||
elif space_drawer == "altair": | ||
# draw with the default implementation | ||
SpaceAltair(model, agent_portrayal, dependencies=[current_step.value]) | ||
elif space_drawer: | ||
# if specified, draw agent space with an alternate renderer | ||
space_drawer(model, agent_portrayal) | ||
|
@@ -139,6 +149,12 @@ def render_in_browser(): | |
ModelController(model, play_interval, current_step, reset_counter) | ||
with solara.Card("Progress", margin=1, elevation=2): | ||
solara.Markdown(md_text=f"####Step - {current_step}") | ||
with solara.Card("Analytics", margin=1, elevation=2): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is unrelated to the PR. |
||
df = model.datacollector.get_model_vars_dataframe() | ||
for col in list(df.columns): | ||
solara.Markdown( | ||
md_text=f"####Avg. {col} - {df.loc[:, f'{col}'].mean()}" | ||
) | ||
|
||
items = [ | ||
ColorCard(color="white", layout_type=layout_types[i]) | ||
|
@@ -334,6 +350,12 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: Optional[List[any]] = | |
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) | ||
|
||
|
||
@solara.component | ||
def SpaceAltair(model, agent_portrayal, dependencies: Optional[List[any]] = None): | ||
grid = create_grid(color="wealth") | ||
grid(model) | ||
|
||
|
||
def _draw_grid(space, space_ax, agent_portrayal): | ||
def portray(g): | ||
x = [] | ||
|
@@ -424,7 +446,7 @@ def get_initial_grid_layout(layout_types): | |
grid_lay = [] | ||
y_coord = 0 | ||
for ii in range(len(layout_types)): | ||
template_layout = {"h": 10, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0} | ||
template_layout = {"h": 20, "i": 0, "moved": False, "w": 6, "y": 0, "x": 0} | ||
if ii == 0: | ||
grid_lay.append(template_layout) | ||
else: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,7 @@ dependencies = [ | |
"pandas", | ||
"solara", | ||
"tqdm", | ||
"altair" | ||
] | ||
dynamic = ["version"] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my current implementation, I'm using an
agent_portrayal
method to generate the values needed to draw the space. That may be drawn from the old mesa visualization approach, IDK if there's a good way to pass in something like that to jupyterviz.I mention because I wonder if it would be cleaner and more explicit than the way you're using json to dump and filter the agent dict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @rlskoeser; the JSON approach is not very clean. Is there a more explicit way to do this?