Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/examples/dimensions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
"| line | x, color, line_dash, facet_col, facet_row, animation_frame |\n",
"| scatter | x, color, symbol, facet_col, facet_row, animation_frame |\n",
"| bar | x, color, facet_col, facet_row, animation_frame |\n",
"| imshow | x, y, facet_col, facet_row, animation_frame |"
"| imshow | y, x, facet_col, facet_row, animation_frame |"
]
},
{
Expand Down Expand Up @@ -186,6 +186,16 @@
"xpx(data_3d).line(x=\"year\", facet_col=\"metric\", facet_row=\"country\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Heatmap grids: imshow supports facet_col and facet_row (requires plotly>=6.7)\n",
"xpx(data_3d).imshow(facet_row=\"metric\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
104 changes: 103 additions & 1 deletion tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import xarray as xr

import xarray_plotly # noqa: F401 - registers accessor
from xarray_plotly import xpx
from xarray_plotly import plotting, xpx
from xarray_plotly.plotting import _imshow_supports_facet_row


class TestXpxFunction:
Expand Down Expand Up @@ -403,6 +404,107 @@ def test_imshow_animation_consistent_bounds(self) -> None:
assert coloraxis.cmax == 70.0


requires_imshow_facet_row = pytest.mark.skipif(
not _imshow_supports_facet_row(),
reason="facet_row in px.imshow requires plotly>=6.7.0",
)


class TestImshowFaceting:
"""Tests for imshow facet_col and facet_row."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
"""Set up test data."""
self.da_3d = xr.DataArray(
np.random.rand(4, 5, 3),
dims=["lat", "lon", "scenario"],
coords={"scenario": ["a", "b", "c"]},
name="temperature",
)
self.da_4d = xr.DataArray(
np.random.rand(4, 5, 2, 3),
dims=["lat", "lon", "scenario", "year"],
coords={"scenario": ["low", "high"], "year": [2020, 2021, 2022]},
name="temperature",
)

def test_imshow_facet_col(self) -> None:
"""Test imshow with facet_col creates one subplot per value."""
fig = self.da_3d.plotly.imshow()
assert len(fig.data) == 3
xaxes = [k for k in fig.layout if k.startswith("xaxis")]
assert len(xaxes) == 3

@requires_imshow_facet_row
def test_imshow_facet_row_explicit(self) -> None:
"""Test imshow with explicit facet_row creates one subplot row per value."""
fig = self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario")
assert len(fig.data) == 3
yaxes = [k for k in fig.layout if k.startswith("yaxis")]
assert len(yaxes) == 3
annotations = {a.text for a in fig.layout.annotations}
assert annotations == {"scenario=a", "scenario=b", "scenario=c"}

@requires_imshow_facet_row
def test_imshow_facet_row_auto_4d(self) -> None:
"""Test that a 4D array auto-assigns facet_col and facet_row."""
fig = self.da_4d.plotly.imshow()
# 2 facet columns (scenario) x 3 facet rows (year)
assert len(fig.data) == 6
annotations = {a.text for a in fig.layout.annotations}
assert annotations == {
"scenario=low",
"scenario=high",
"year=2020",
"year=2021",
"year=2022",
}

@requires_imshow_facet_row
def test_imshow_facet_grid_consistent_bounds(self) -> None:
"""Test that facet grid subplots share global color bounds."""
da = xr.DataArray(
np.arange(24, dtype=float).reshape(2, 2, 2, 3),
dims=["y", "x", "scenario", "year"],
)
fig = da.plotly.imshow()
coloraxis = fig.layout.coloraxis
assert coloraxis.cmin == 0.0
assert coloraxis.cmax == 23.0

@requires_imshow_facet_row
def test_imshow_facet_grid_with_animation(self) -> None:
"""Test imshow with facet_col, facet_row, and animation_frame together."""
da = xr.DataArray(
np.random.rand(4, 5, 2, 3, 6),
dims=["lat", "lon", "scenario", "year", "time"],
name="temperature",
)
fig = da.plotly.imshow()
assert len(fig.data) == 6
assert len(fig.frames) == 6

def test_imshow_explicit_facet_row_unsupported_error(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test informative error when facet_row is requested on old plotly."""
monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False)
with pytest.raises(ValueError, match=r"facet_row for imshow requires plotly>=6\.7\.0"):
self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario")

def test_imshow_auto_skips_facet_row_on_old_plotly(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Test that auto-assignment skips facet_row on old plotly (4th dim animates)."""
monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False)
fig = self.da_4d.plotly.imshow()
# year (4th dim) falls through to animation_frame instead of facet_row
assert len(fig.frames) == 3
# only the facet_col (scenario) produces subplots
assert len(fig.data) == 2


class TestColorsParameter:
"""Tests for the unified colors parameter."""

Expand Down
17 changes: 17 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ def test_auto_assignment_imshow(self) -> None:
slots = assign_slots(["lat", "lon"], "imshow")
assert slots == {"y": "lat", "x": "lon"}

def test_auto_assignment_imshow_4d(self) -> None:
"""Test that the fourth dimension fills facet_row for imshow."""
slots = assign_slots(["lat", "lon", "scenario", "year"], "imshow")
assert slots == {"y": "lat", "x": "lon", "facet_col": "scenario", "facet_row": "year"}

def test_auto_assignment_imshow_5d(self) -> None:
"""Test that the fifth dimension fills animation_frame for imshow."""
slots = assign_slots(["lat", "lon", "scenario", "year", "time"], "imshow")
assert slots == {
"y": "lat",
"x": "lon",
"facet_col": "scenario",
"facet_row": "year",
"animation_frame": "time",
}

def test_auto_assignment_scatter(self) -> None:
"""Test automatic positional assignment for scatter plots."""
slots = assign_slots(["x_dim", "color_dim"], "scatter")
Expand Down Expand Up @@ -124,5 +140,6 @@ def test_imshow_slot_order(self) -> None:
"y",
"x",
"facet_col",
"facet_row",
"animation_frame",
)
9 changes: 7 additions & 2 deletions xarray_plotly/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,15 @@ def imshow(
x: SlotValue = auto,
y: SlotValue = auto,
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
robust: bool = False,
colors: Colors = None,
**px_kwargs: Any,
) -> go.Figure:
"""Create an interactive heatmap image.

Slot order: y (rows) -> x (columns) -> facet_col -> animation_frame
Slot order: y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame

Note:
**Difference from px.imshow**: Color bounds are computed from the
Expand All @@ -320,7 +321,10 @@ def imshow(
x: Dimension for x-axis (columns). Default: second dimension.
y: Dimension for y-axis (rows). Default: first dimension.
facet_col: Dimension for subplot columns. Default: third dimension.
animation_frame: Dimension for animation. Default: fourth dimension.
facet_row: Dimension for subplot rows. Default: fourth dimension.
Requires plotly>=6.7.0; on older versions this slot is skipped
during auto-assignment.
animation_frame: Dimension for animation. Default: fifth dimension.
robust: If True, use 2nd/98th percentiles for color bounds (handles outliers).
colors: Color scale name (e.g., "Viridis", "RdBu"). See module docs.
**px_kwargs: Additional arguments passed to `plotly.express.imshow()`.
Expand All @@ -334,6 +338,7 @@ def imshow(
x=x,
y=y,
facet_col=facet_col,
facet_row=facet_row,
animation_frame=animation_frame,
robust=robust,
colors=colors,
Expand Down
2 changes: 1 addition & 1 deletion xarray_plotly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"facet_row",
"animation_frame",
),
"imshow": ("y", "x", "facet_col", "animation_frame"),
"imshow": ("y", "x", "facet_col", "facet_row", "animation_frame"),
"box": ("x", "color", "facet_col", "facet_row", "animation_frame"),
"pie": ("names", "facet_col", "facet_row"),
}
Expand Down
43 changes: 39 additions & 4 deletions xarray_plotly/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import inspect
import warnings
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -614,12 +615,21 @@ def scatter(
)


def _imshow_supports_facet_row() -> bool:
"""Check whether the installed plotly version supports facet_row in px.imshow.

Support was added in plotly 6.7.0.
"""
return "facet_row" in inspect.signature(px.imshow).parameters


def imshow(
darray: DataArray,
*,
x: SlotValue = auto,
y: SlotValue = auto,
facet_col: SlotValue = auto,
facet_row: SlotValue = auto,
animation_frame: SlotValue = auto,
robust: bool = False,
colors: Colors = None,
Expand All @@ -629,7 +639,7 @@ def imshow(
Create an interactive heatmap from a DataArray.

Both x and y are dimensions. Dimensions fill slots in order:
y (rows) -> x (columns) -> facet_col -> animation_frame
y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame

.. note::
**Difference from plotly.express.imshow**: By default, color bounds
Expand All @@ -649,8 +659,14 @@ def imshow(
Dimension for y-axis (rows). Default: first dimension.
facet_col
Dimension for subplot columns. Default: third dimension.
facet_row
Dimension for subplot rows. Default: fourth dimension.
Requires plotly>=6.7.0; on older versions this slot is skipped
during auto-assignment (the fourth dimension animates instead).
Note: ``facet_col_wrap`` is ignored by plotly when ``facet_row``
is set.
animation_frame
Dimension for animation. Default: fourth dimension.
Dimension for animation. Default: fifth dimension.
robust
If True, compute color bounds using 2nd and 98th percentiles
for robustness against outliers. Default: False (uses min/max).
Expand All @@ -668,18 +684,36 @@ def imshow(
plotly.graph_objects.Figure
"""
px_kwargs = resolve_colors(colors, px_kwargs)

# On plotly < 6.7.0, px.imshow has no facet_row: skip auto-assignment so
# dimensions fall through to animation_frame instead.
if facet_row is auto and not _imshow_supports_facet_row():
facet_row = None

slots = assign_slots(
list(darray.dims),
"imshow",
y=y,
x=x,
facet_col=facet_col,
facet_row=facet_row,
animation_frame=animation_frame,
)

# Transpose to: y (rows), x (cols), facet_col, animation_frame
facet_row_kwargs: dict[str, Any] = {}
if slots.get("facet_row") is not None:
if not _imshow_supports_facet_row():
import plotly

msg = f"facet_row for imshow requires plotly>=6.7.0 (installed: {plotly.__version__})."
raise ValueError(msg)
facet_row_kwargs["facet_row"] = slots["facet_row"]

# Transpose to: y (rows), x (cols), facet_col, facet_row, animation_frame
transpose_order = [
slots[k] for k in ("y", "x", "facet_col", "animation_frame") if slots.get(k) is not None
slots[k]
for k in ("y", "x", "facet_col", "facet_row", "animation_frame")
if slots.get(k) is not None
]
plot_data = darray.transpose(*transpose_order) if transpose_order else darray

Expand All @@ -701,6 +735,7 @@ def imshow(
plot_data,
facet_col=slots.get("facet_col"),
animation_frame=slots.get("animation_frame"),
**facet_row_kwargs,
**px_kwargs,
)

Expand Down